diff --git a/.gitignore b/.gitignore
index 5573d2b9..2579acdd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -143,3 +143,6 @@ docs/generated/
# vscode
.vscode/
+
+# nwb cache
+nwb-cache/
diff --git a/docs/developers_notes/basis_module.md b/docs/developers_notes/01-basis_module.md
similarity index 99%
rename from docs/developers_notes/basis_module.md
rename to docs/developers_notes/01-basis_module.md
index a96d875a..ebb2ace5 100644
--- a/docs/developers_notes/basis_module.md
+++ b/docs/developers_notes/01-basis_module.md
@@ -1,4 +1,4 @@
-# The Basis Module
+# The `basis` Module
## Introduction
diff --git a/docs/developers_notes/02-base_class.md b/docs/developers_notes/02-base_class.md
new file mode 100644
index 00000000..521ff992
--- /dev/null
+++ b/docs/developers_notes/02-base_class.md
@@ -0,0 +1,100 @@
+# The `base_class` Module
+
+## Introduction
+
+The `base_class` module introduces the `Base` class and abstract classes defining broad model categories. These abstract classes **must** inherit from `Base`.
+
+The `Base` class is envisioned as the foundational component for any object type (e.g., regression, dimensionality reduction, clustering, observation models, regularizers, etc.). In contrast, abstract classes derived from `Base` define overarching object categories (e.g., `base_class.BaseRegressor` is the building block for GLMs, GAMs, etc., while `observation_models.Observations` is the building block for Poisson observations, Gamma observations, etc.).
+
+Designed to be compatible with the `scikit-learn` API, the class structure aims to facilitate access to `scikit-learn`'s robust pipeline and cross-validation modules. This is achieved while leveraging the accelerated computational capabilities of `jax` and `jaxopt` in the backend, which is essential for analyzing extensive neural recordings and fitting large models.
+
+Below is a scheme of how we envision the architecture of the `nemos` models.
+
+```
+Abstract Class Base
+│
+├─ Abstract Subclass BaseRegressor
+│ │
+│ └─ Concrete Subclass GLM
+│ │
+│  └─ Concrete Subclass GLMRecurrent
+│
+├─ Abstract Subclass BaseManifold *(not implemented yet)
+│ │
+│ ...
+│
+├─ Abstract Subclass Regularizer
+│ │
+│ ├─ Concrete Subclass UnRegularized
+│ │
+│ ├─ Concrete Subclass Ridge
+│ ...
+│
+├─ Abstract Subclass Observations
+│ │
+│ ├─ Concrete Subclass PoissonObservations
+│ │
+│ ├─ Concrete Subclass GammaObservations *(not implemented yet)
+│ ...
+│
+...
+```
+
+!!! Example
+    The current package version includes a concrete class named `nemos.glm.GLM`. This class inherits from `BaseRegressor`, which in turn inherits from `Base`, since it falls under the "GLM regression" category.
+    As with any `BaseRegressor`, it **must** implement the `fit`, `score`, `predict`, and `simulate` methods.
+
+
+## The Class `base_class.Base`
+
+The `Base` class aligns with the `scikit-learn` API for `base.BaseEstimator`. This alignment is achieved by implementing the `get_params` and `set_params` methods, essential for `scikit-learn` compatibility and foundational for all model implementations. Additionally, the class provides auxiliary helper methods to identify available computational devices (such as GPUs and TPUs) and to facilitate data transfer to these devices.
+
+For a detailed understanding, consult the [`scikit-learn` API Reference](https://scikit-learn.org/stable/modules/classes.html) and [`BaseEstimator`](https://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html).
+
+!!! Note
+ We've intentionally omitted the `get_metadata_routing` method. Given its current experimental status and its lack of relevance to the `GLM` class, this method was excluded. Should future needs arise around parameter routing, consider directly inheriting from `sklearn.BaseEstimator`. More information can be found [here](https://scikit-learn.org/stable/metadata_routing.html#metadata-routing).
+
+### Public Methods
+
+- **`get_params`**: The `get_params` method retrieves parameters set during model instance initialization. Opting for a deep inspection allows the method to assess nested object parameters, resulting in a comprehensive parameter dictionary.
+- **`set_params`**: The `set_params` method offers a mechanism to adjust or set an estimator's parameters. It's versatile, accommodating both individual estimators and more complex nested structures like pipelines. Feeding an unrecognized parameter will raise a `ValueError` (see the sketch below).
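+
+A minimal sketch of how these methods behave is shown below; `MyEstimator` and its parameters are hypothetical and serve purely for illustration:
+
+```python
+from nemos.base_class import Base
+
+
+class MyEstimator(Base):
+    def __init__(self, alpha=1.0, solver_name="GradientDescent"):
+        self.alpha = alpha
+        self.solver_name = solver_name
+
+
+est = MyEstimator()
+
+# retrieve the parameters captured from __init__
+print(est.get_params())  # {'alpha': 1.0, 'solver_name': 'GradientDescent'}
+
+# update a parameter; set_params returns the estimator itself
+est.set_params(alpha=0.5)
+
+# an unrecognized parameter raises a ValueError:
+# est.set_params(gamma=1.0)
+```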
+
+## The Abstract Class `base_class.BaseRegressor`
+
+`BaseRegressor` is an abstract class that inherits from `Base`, requiring the implementation of the abstract methods `fit`, `predict`, `score`, and `simulate`. This ensures seamless integration with `scikit-learn` pipelines and cross-validation procedures.
+
+### Abstract Methods
+
+For subclasses derived from `BaseRegressor` to function correctly, they must implement the following:
+
+1. `fit`: Adapt the model using input data `X` and corresponding observations `y`.
+2. `predict`: Provide predictions based on the trained model and input data `X`.
+3. `score`: Score the accuracy of model predictions using input data `X` against the actual observations `y`.
+4. `simulate`: Simulate data based on the trained regression model.
+
+### Preprocessing Methods
+
+To ensure the consistency and conformity of input data, `BaseRegressor` introduces two preprocessing methods:
+
+1. `_preprocess_fit`: Validates and converts the inputs of the `fit` method into the desired `jax.numpy.ndarray` format. If necessary, this method initializes model parameters with default values.
+2. `_preprocess_simulate`: Validates and converts the inputs of the `simulate` method. This method confirms the integrity of the feedforward input and, when provided, the initial values for feedback.
+
+### Auxiliary Methods
+
+Moreover, `BaseRegressor` incorporates auxiliary methods such as `_convert_to_jnp_ndarray` and `_has_invalid_entry`,
+as well as a number of other methods for checking input consistency.
+
+!!! Tip
+    Deciding between concrete and abstract methods in a superclass can be nuanced. As a general guideline: any method that's expected in all subclasses and isn't subclass-specific should be concretely implemented in the superclass. Conversely, methods essential to a subclass's expected behavior, but that vary from subclass to subclass, should be abstract in the superclass. For instance, compatibility with the `sklearn.cross_validation` module demands the `score`, `fit`, `get_params`, and `set_params` methods. Given their specificity to individual models, `score` and `fit` are abstract in `BaseRegressor`. Conversely, since `get_params` and `set_params` are consistent across model classes, they're concretely implemented in `Base` and inherited. This approach typifies our general implementation strategy. However, while these are sound guidelines, exceptions exist based on factors like future extensibility, clarity, and maintainability.
+
+
+## Contributor Guidelines
+
+### Implementing Model Subclasses
+
+When devising a new model subclass based on the `BaseRegressor` abstract class, adhere to the following guidelines (a minimal sketch follows the list):
+
+- **Must** inherit from the `BaseRegressor` abstract superclass.
+- **Must** implement the abstract methods: `fit`, `predict`, `score`, and `simulate`.
+- **Should not** overwrite the `get_params` and `set_params` methods, inherited from `Base`.
+- **May** introduce auxiliary methods such as `_convert_to_jnp_ndarray` for added utility.
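+
+The following is a deliberately minimal sketch of a conforming subclass. The model itself (a constant-rate predictor) is hypothetical and only illustrates the required structure:
+
+```python
+import jax
+import jax.numpy as jnp
+
+from nemos.base_class import BaseRegressor
+
+
+class MeanRateModel(BaseRegressor):
+    """Toy regressor predicting a constant rate per neuron."""
+
+    def fit(self, X, y):
+        # X: (n_timebins, n_neurons, n_features), y: (n_timebins, n_neurons)
+        self.rate_ = jnp.mean(y, axis=0)
+        return self
+
+    def predict(self, X):
+        # broadcast the per-neuron mean rate over time bins
+        return jnp.broadcast_to(self.rate_, (X.shape[0], self.rate_.shape[0]))
+
+    def score(self, X, y, **kwargs):
+        # negative mean squared error, so that higher is better
+        return -jnp.mean((self.predict(X) - y) ** 2)
+
+    def simulate(self, random_key, feed_forward_input):
+        # draw Poisson samples around the predicted rates
+        return jax.random.poisson(random_key, self.predict(feed_forward_input))
+```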
diff --git a/docs/developers_notes/03-observation_models.md b/docs/developers_notes/03-observation_models.md
new file mode 100644
index 00000000..b7429bcd
--- /dev/null
+++ b/docs/developers_notes/03-observation_models.md
@@ -0,0 +1,64 @@
+# The `observation_models` Module
+
+## Introduction
+
+The `observation_models` module provides objects representing the observations of GLM-like models.
+
+The abstract class `Observations` defines the structure of the subclasses which specify observation types, such as Poisson, Gamma, etc. These objects serve as attributes of the [`nemos.glm.GLM`](../05-glm/#the-concrete-class-glm) class, equipping the GLM with a negative log-likelihood. The negative log-likelihood is used to define the optimization objective, the deviance (which measures model fit quality), and the emission of new observations for simulating new data.
+
+## The Abstract class `Observations`
+
+The abstract class `Observations` is the backbone of any observation model. Any class inheriting `Observations` must reimplement the `negative_log_likelihood`, `sample_generator`, `residual_deviance`, and `estimate_scale` methods.
+
+### Abstract Methods
+
+For subclasses derived from `Observations` to function correctly, they must implement the following:
+
+- **negative_log_likelihood**: Computes the negative-log likelihood of the model up to a normalization constant. This method is usually part of the objective function used to learn GLM parameters.
+
+- **sample_generator**: Returns the random emission probability function. This typically invokes a `jax.random` sampler for the emission distribution, given some sufficient statistics[^1]. For distributions in the exponential family, the sufficient statistics are the canonical parameter and the scale. In GLMs, the canonical parameter is entirely specified by the model's weights, while the scale is either fixed (e.g., Poisson) or needs to be estimated (e.g., Gamma).
+
+- **residual_deviance**: Computes the residual deviance based on the model's estimated rates and observations.
+
+- **estimate_scale**: A method for estimating the scale parameter of the model.
+
+### Public Methods
+
+- **pseudo_r2**: Method for computing the pseudo-$R^2$ of the model based on the residual deviance. There is no consensus definition of the pseudo-$R^2$; the one used here is the definition by Cohen et al.[^2], sketched below.
+
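+In deviance terms, this pseudo-$R^2$ can be sketched as follows, where $D_M$ and $D_0$ denote the deviance of the fitted model and of the intercept-only null model, respectively:
+
+$$
+R^2_{\text{pseudo}} = \frac{D_0 - D_M}{D_0} = 1 - \frac{D_M}{D_0}.
+$$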
+
+### Auxiliary Methods
+
+- **_check_inverse_link_function**: Checks that the provided link function is a `Callable` from the `jax` namespace.
+
+## Concrete `PoissonObservations` class
+
+The `PoissonObservations` class extends the abstract `Observations` class to provide functionalities specific to the Poisson observation model. It is designed for modeling observed spike counts based on a Poisson distribution with a given rate.
+
+### Overridden Methods
+
+- **negative_log_likelihood**: This method computes the Poisson negative log-likelihood of the predicted rates for the observed spike counts.
+
+- **sample_generator**: Generates random numbers from a Poisson distribution based on the given `predicted_rate`.
+
+- **residual_deviance**: Calculates the residual deviance for a Poisson model.
+
+- **estimate_scale**: Assigns a fixed value of 1 to the scale parameter of the Poisson model since Poisson distribution has a fixed scale.
+
+## Contributor Guidelines
+
+To implement an observation model class, you:
+
+- **Must** inherit from `Observations`
+
+- **Must** provide a concrete implementation of `negative_log_likelihood`, `sample_generator`, `residual_deviance`, and `estimate_scale`.
+
+- **Should not** reimplement the `pseudo_r2` method or the `_check_inverse_link_function` auxiliary method.
+
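+Below is a minimal sketch of a conforming subclass, assuming a hypothetical Bernoulli observation model; argument names are illustrative, and the likelihood and deviance expressions are the standard Bernoulli ones rather than code taken from `nemos`.
+
+```python
+import jax
+import jax.numpy as jnp
+
+from nemos.observation_models import Observations
+
+
+class BernoulliObservations(Observations):
+    """Hypothetical observation model for binary (0/1) observations."""
+
+    def negative_log_likelihood(self, predicted_rate, y):
+        p = jnp.clip(predicted_rate, 1e-6, 1 - 1e-6)
+        return -jnp.mean(y * jnp.log(p) + (1 - y) * jnp.log(1 - p))
+
+    def sample_generator(self, key, predicted_rate):
+        # emission probability via jax.random
+        return jax.random.bernoulli(key, predicted_rate)
+
+    def residual_deviance(self, predicted_rate, y):
+        # Bernoulli deviance residuals
+        p = jnp.clip(predicted_rate, 1e-6, 1 - 1e-6)
+        ll = y * jnp.log(p) + (1 - y) * jnp.log(1 - p)
+        return jnp.sign(y - p) * jnp.sqrt(-2 * ll)
+
+    def estimate_scale(self, predicted_rate):
+        # the Bernoulli distribution has a fixed scale
+        self.scale = 1.0
+```
+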
+[^1]:
+ In statistics, a statistic is sufficient with respect to a statistical model and its associated unknown parameters if "no other statistic that can be calculated from the same sample provides any additional information as to the value of the parameters", adapted from Fisher R. A.
+ 1922. On the mathematical foundations of theoretical statistics. *Philosophical Transactions of the Royal Society of London. Series A, Containing Papers of a Mathematical or Physical Character* 222:309–368. http://doi.org/10.1098/rsta.1922.0009.
+[^2]:
+ Jacob Cohen, Patricia Cohen, Steven G. West, Leona S. Aiken.
+ *Applied Multiple Regression/Correlation Analysis for the Behavioral Sciences*.
+ 3rd edition. Routledge, 2002. p.502. ISBN 978-0-8058-2223-6. (May 2012)
diff --git a/docs/developers_notes/04-regularizer.md b/docs/developers_notes/04-regularizer.md
new file mode 100644
index 00000000..3026c162
--- /dev/null
+++ b/docs/developers_notes/04-regularizer.md
@@ -0,0 +1,166 @@
+# The `regularizer` Module
+
+## Introduction
+
+The `regularizer` module introduces an archetype class `Regularizer` which provides the structural components for each concrete sub-class.
+
+Objects of type `Regularizer` provide methods to define a regularized optimization objective and instantiate a solver for it. These objects serve as attributes of the [`nemos.glm.GLM`](../05-glm/#the-concrete-class-glm) class, equipping the GLM with a solver for learning model parameters.
+
+Solvers are typically optimizers from the `jaxopt` package, but in principle they could be custom optimization routines, as long as they respect the `jaxopt` API (i.e., have `run` and `update` methods with the appropriate input/output types); a sketch of this interface follows below.
+We chose to rely on `jaxopt` because it provides a comprehensive set of robust, GPU-accelerated, batchable, and differentiable optimizers in JAX that are highly customizable.
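+
+As a rough illustration of the expected interface, here is a bare-bones, hypothetical gradient-descent solver; a real drop-in replacement would also need to match the return types of the `jaxopt` solvers:
+
+```python
+import jax
+
+
+class GradientStepSolver:
+    """Illustrative solver exposing a jaxopt-style run/update interface."""
+
+    def __init__(self, fun, stepsize=1e-2, maxiter=100):
+        self.fun = fun  # the (possibly penalized) loss function
+        self.stepsize = stepsize
+        self.maxiter = maxiter
+        self._grad = jax.grad(fun)
+
+    def update(self, params, state, *args):
+        # one optimization step: plain gradient descent on the loss
+        grads = self._grad(params, *args)
+        params = jax.tree_util.tree_map(
+            lambda p, g: p - self.stepsize * g, params, grads
+        )
+        return params, state
+
+    def run(self, init_params, *args):
+        # full optimization loop, returning (params, state) like jaxopt
+        params, state = init_params, None
+        for _ in range(self.maxiter):
+            params, state = self.update(params, state, *args)
+        return params, state
+```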
+
+Each `Regularizer` object defines a set of allowed optimizers, which in turn depends on the loss function characteristics (smooth vs non-smooth) and/or the optimization type (constrained, un-constrained, batched, etc.).
+
+```
+Abstract Class Regularizer
+|
+├─ Concrete Class UnRegularized
+|
+├─ Concrete Class Ridge
+|
+└─ Abstract Class ProxGradientRegularizer
+ |
+ ├─ Concrete Class Lasso
+ |
+ └─ Concrete Class GroupLasso
+```
+
+!!! note
+    If we need advanced adaptive optimizers (e.g., Adam, LAMB, etc.) in the future, we should consider adding [`Optax`](https://optax.readthedocs.io/en/latest/) as a dependency, which is compatible with `jaxopt`; see [here](https://jaxopt.github.io/stable/_autosummary/jaxopt.OptaxSolver.html#jaxopt.OptaxSolver).
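+
+    For reference, wrapping an `Optax` optimizer would look roughly like this (a hypothetical usage sketch; `loss_function`, `init_params`, `exog_vars`, and `endog_vars` are assumed to be defined as in the examples below):
+
+    ```python
+    import jaxopt
+    import optax
+
+    # expose an Optax optimizer through the jaxopt run/update interface
+    solver = jaxopt.OptaxSolver(fun=loss_function, opt=optax.adam(learning_rate=1e-3))
+    params, state = solver.run(init_params, exog_vars, endog_vars)
+    ```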
+
+## The Abstract Class `Regularizer`
+
+The abstract class `Regularizer` enforces the implementation of the `instantiate_solver` method on any concrete realization of a `Regularizer` object. `Regularizer` objects are equipped with a method for instantiating a solver runner with the appropriately regularized loss function, i.e., a function that receives as input the initial parameters, the endogenous and the exogenous variables, and outputs the optimization results.
+
+Additionally, the class provides auxiliary methods for checking that the solver and loss function specifications are valid.
+
+### Public Methods
+
+- **`instantiate_solver`**: Instantiates a solver runner for a provided loss function, configuring and returning a `solver_run` callable. The loss function must be of type `Callable`.
+
+### Auxiliary Methods
+
+- **`_check_solver`**: This method ensures that the provided solver name is in the list of allowed solvers for the specific `Regularizer` object. This is crucial for maintaining consistency and correctness in the solver's operation.
+
+- **`_check_solver_kwargs`**: This method checks if the provided keyword arguments are valid for the specified solver. This helps in catching and preventing potential errors in solver configuration.
+
+## The `UnRegularized` Class
+
+The `UnRegularized` class extends the base `Regularizer` class and is designed specifically for optimizing unregularized models. This means that the solver instantiated by this class does not add any regularization penalty to the loss function during the optimization process.
+
+### Attributes
+
+- **`allowed_solvers`**: A list of string identifiers for the optimization solvers that can be used with this regularizer class. The optimization methods listed here are specifically suitable for unregularized optimization problems.
+
+### Methods
+
+- **`__init__`**: The constructor method for this class which initializes a new `UnRegularized` object. It accepts the name of the solver algorithm to use (`solver_name`) and an optional dictionary of additional keyword arguments (`solver_kwargs`) for the solver.
+
+- **`instantiate_solver`**: A method which prepares and returns a runner function for the specified loss function. This method ensures that the loss function is callable and prepares the necessary keyword arguments for calling the `get_runner` method from the base `Regularizer` class.
+
+### Example Usage
+
+```python
+unregularized = UnRegularized(solver_name="GradientDescent")
+runner = unregularized.instantiate_solver(loss_function)
+optim_results = runner(init_params, exog_vars, endog_vars)
+```
+
+## The `Ridge` Class
+
+The `Ridge` class extends the `Regularizer` class to handle optimization problems with Ridge regularization. Ridge regularization adds a penalty to the loss function, proportional to the sum of squares of the model parameters, to prevent overfitting and stabilize the optimization.
+
+### Attributes
+
+- **`allowed_solvers`**: A list containing string identifiers of optimization solvers compatible with Ridge regularization.
+
+- **`regularizer_strength`**: A floating-point value determining the strength of the Ridge regularization. Higher values correspond to stronger regularization which tends to drive the model parameters towards zero.
+
+### Methods
+
+- **`__init__`**: The constructor method for the `Ridge` class. It accepts the name of the solver algorithm (`solver_name`), an optional dictionary of additional keyword arguments (`solver_kwargs`) for the solver, and the regularization strength (`regularizer_strength`).
+
+- **`penalization`**: A method to compute the Ridge regularization penalty for a given set of model parameters.
+
+- **`instantiate_solver`**: A method that prepares and returns a runner function with a penalized loss function for Ridge regularization. This method modifies the original loss function to include the Ridge penalty, ensures the loss function is callable, and prepares the necessary keyword arguments for calling the `get_runner` method from the base `Regularizer` class.
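+
+Up to the normalization conventions used internally, the penalized objective has the familiar form below, where $\alpha$ is the `regularizer_strength` and $w$ are the model coefficients:
+
+$$
+\mathcal{L}_{\text{ridge}}(w) = \mathcal{L}(w) + \alpha \sum_j w_j^2.
+$$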
+
+### Example Usage
+
+```python
+ridge = Ridge(solver_name="LBFGS", regularizer_strength=1.0)
+runner = ridge.instantiate_solver(loss_function)
+optim_results = runner(init_params, exog_vars, endog_vars)
+```
+
+## `ProxGradientRegularizer` Class
+
+The `ProxGradientRegularizer` class extends the `Regularizer` class to utilize the Proximal Gradient method for optimization. It leverages the `jaxopt` library's Proximal Gradient optimizer, introducing the functionality of a proximal operator.
+
+### Attributes
+
+- **`allowed_solvers`**: A list containing string identifiers of optimization solvers compatible with this regularizer, specifically "ProximalGradient".
+
+### Methods
+
+- **`__init__`**: The constructor method for the `ProxGradientRegularizer` class. It accepts the name of the solver algorithm (`solver_name`), an optional dictionary of additional keyword arguments (`solver_kwargs`) for the solver, the regularization strength (`regularizer_strength`), and an optional mask array (`mask`).
+
+- **`get_prox_operator`**: Abstract method to retrieve the proximal operator for this solver.
+
+- **`instantiate_solver`**: Method to prepare and return a runner function for optimization with a provided loss function and proximal operator.
+
+## `Lasso` Class
+
+The `Lasso` class extends `ProxGradientRegularizer` to specialize in optimization using the Lasso (L1 regularization) method with Proximal Gradient.
+
+### Methods
+
+- **`__init__`**: Constructor method similar to `ProxGradientRegularizer` but defaults `solver_name` to "ProximalGradient".
+
+- **`get_prox_operator`**: Method to retrieve the proximal operator for Lasso regularization (L1 penalty).
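+
+For reference, the proximal operator of the L1 penalty is the well-known soft-thresholding operator; writing $\gamma$ for the product of the step size and the regularization strength:
+
+$$
+\operatorname{prox}_{\gamma \lVert \cdot \rVert_1}(w)_j = \operatorname{sign}(w_j)\,\max(|w_j| - \gamma,\, 0).
+$$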
+
+## `GroupLasso` Class
+
+The `GroupLasso` class extends `ProxGradientRegularizer` to specialize in optimization using the Group Lasso regularization method with Proximal Gradient. It induces sparsity on groups of features rather than individual features.
+
+### Attributes
+
+- **`mask`**: A mask array indicating groups of features for regularization.
+
+### Methods
+
+- **`__init__`**: Constructor method similar to `ProxGradientRegularizer`, but additionally requires a `mask` array to identify groups of features.
+
+- **`get_prox_operator`**: Method to retrieve the proximal operator for Group Lasso regularization.
+
+- **`_check_mask`**: Static method to check that the provided mask is a float `jax.numpy.ndarray` of 0s and 1s. The mask must be in floats to be applied correctly through the linear algebra operations of the `nemos.proximal_operator.prox_group_lasso` function.
+
+### Example Usage
+```python
+lasso = Lasso(regularizer_strength=1.0)
+runner = lasso.instantiate_solver(loss_function)
+optim_results = runner(init_params, exog_vars, endog_vars)
+
+# float mask of 0s and 1s, shape (n_groups, n_features);
+# row i flags the features that belong to group i
+group_mask = np.zeros((2, 5))
+group_mask[0, [0, 4]] = 1.0
+group_mask[1, 1:4] = 1.0
+
+group_lasso = GroupLasso(solver_name="ProximalGradient", mask=group_mask, regularizer_strength=1.0)
+runner = group_lasso.instantiate_solver(loss_function)
+optim_results = runner(init_params, exog_vars, endog_vars)
+```
+
+## Contributor Guidelines
+
+### Implementing `Regularizer` Subclasses
+
+When developing a functional (i.e., concrete) `Regularizer` class:
+
+- **Must** inherit from `Regularizer` or one of its derivatives.
+- **Must** implement the `instantiate_solver` method to tailor the solver instantiation based on the provided loss function.
+- For any Proximal Gradient method, **must** include a `get_prox_operator` method to define the proximal operator.
+- **Must** possess an `allowed_solvers` attribute to list the solver names that are permissible to be used with this regularizer.
+- **May** embed additional attributes and methods such as `mask` and `_check_mask` if required by the specific regularizer subclass for handling special optimization scenarios.
+- **May** include a `regularizer_strength` attribute to control the strength of the regularization in scenarios where regularization is applicable.
+- **May** rely on a custom solver implementation for specific optimization problems, but the implementation **must** adhere to the `jaxopt` API.
+
+These guidelines ensure that each `Regularizer` subclass adheres to a consistent structure and behavior, facilitating ease of extension and maintenance.
+
+## Glossary
+
+| Term | Description |
+|--------------------| ----------- |
+| **Regularization** | Regularization is a technique used to prevent overfitting by adding a penalty to the loss function, which discourages complex models. Common regularization techniques include L1 (Lasso) and L2 (Ridge) regularization. |
+| **Optimization** | Optimization refers to the process of minimizing (or maximizing) a function by systematically choosing the values of the variables within an allowable set. In machine learning, optimization aims to minimize the loss function to train models. |
+| **Solver** | A solver is an algorithm or a set of algorithms used for solving optimization problems. In the given module, solvers are used to find the parameters that minimize the loss function, potentially subject to some constraints. |
+| **Runner** | A runner in this context refers to a callable function configured to execute the solver with the specified parameters and data. It acts as an interface to the solver, simplifying the process of running optimization tasks. |
diff --git a/docs/developers_notes/05-glm.md b/docs/developers_notes/05-glm.md
new file mode 100644
index 00000000..dc01329d
--- /dev/null
+++ b/docs/developers_notes/05-glm.md
@@ -0,0 +1,76 @@
+# The `glm` Module
+
+## Introduction
+
+
+
+Generalized Linear Models (GLMs) provide a flexible framework for modeling a variety of data types while establishing a relationship between multiple predictors and a response variable. A GLM extends traditional linear regression by allowing response variables with error distributions other than the normal distribution, such as the binomial or Poisson distributions.
+
+The `nemos.glm` module currently offers implementations of two GLM classes:
+
+1. **`GLM`:** A direct implementation of a feedforward GLM.
+2. **`GLMRecurrent`:** An implementation of a recurrent GLM. This class inherits from `GLM` and adds a `simulate_recurrent` method to generate spikes akin to a recurrent neural network.
+
+Our design aligns with the `scikit-learn` API, facilitating seamless integration of our GLM classes with the well-established `scikit-learn` pipeline and its cross-validation tools.
+
+The classes provided here are modular by design, offering a standard foundation for any GLM variant.
+
+Instantiating a specific GLM simply requires providing an observation model (Gamma, Poisson, etc.) and a regularization strategy (Ridge, Lasso, etc.) during initialization. This is done using the [`nemos.observation_models.Observations`](../03-observation_models/#the-abstract-class-observations) and [`nemos.regularizer.Regularizer`](../04-regularizer/#the-abstract-class-regularizer) objects, respectively.
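+
+For instance, a minimal configuration, mirroring the demo in the examples gallery, looks like this:
+
+```python
+import nemos as nmo
+
+# Poisson observations paired with a Ridge-penalized LBFGS solver
+model = nmo.glm.GLM(
+    observation_model=nmo.observation_models.PoissonObservations(),
+    regularizer=nmo.regularizer.Ridge(solver_name="LBFGS", regularizer_strength=0.1),
+)
+```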
+
+
+
+
+
+
+## The Concrete Class `GLM`
+
+The `GLM` class provides a direct implementation of the GLM model and is designed with `scikit-learn` compatibility in mind.
+
+### Inheritance
+
+`GLM` inherits from [`BaseRegressor`](../02-base_class/#the-abstract-class-baseregressor). This inheritance mandates the direct implementation of methods like `predict`, `fit`, `score`, and `simulate`.
+
+### Attributes
+
+- **`regularizer`**: Refers to the optimization regularizer, an object of the [`nemos.regularizer.Regularizer`](../04-regularizer/#the-abstract-class-regularizer) type. It uses the `jaxopt` solver to minimize the (penalized) negative log-likelihood of the GLM.
+- **`observation_model`**: Represents the GLM observation model, which is an object of the [`nemos.observation_models.Observations`](../03-observation_models/#the-abstract-class-observations) type. This model determines the log-likelihood and the emission probability mechanism for the `GLM`.
+- **`coef_`**: Stores the solution for the spike basis coefficients as a `jax.numpy.ndarray` after the fitting process. It is initialized as `None` at class instantiation.
+- **`intercept_`**: Stores the bias terms' solutions as a `jax.numpy.ndarray` after the fitting process. It is initialized as `None` at class instantiation.
+- **`solver_state`**: Indicates the solver's state. For specific solver states, refer to the [`jaxopt` documentation](https://jaxopt.github.io/stable/index.html#).
+
+### Public Methods
+
+- **`predict`**: Validates the input and computes the mean rates of the `GLM` by invoking the inverse-link function of the `observation_model` attribute.
+- **`score`**: Validates the input and assesses the Poisson GLM using either the log-likelihood or the pseudo-$R^2$. This method relies on the `observation_model` to compute these metrics.
+- **`fit`**: Validates the input and fits the Poisson GLM to the spike train data. It leverages the `observation_model` and `regularizer` to define the model's loss function and instantiate the solver.
+- **`simulate`**: Simulates spike trains using the GLM as a feedforward network, invoking the `observation_model.sample_generator` method for the emission probability.
+
+### Private Methods
+
+- **`_predict`**: Forecasts rates based on the current model parameters and the inverse-link function of the `observation_model`.
+- **`_score`**: Determines the Poisson negative log-likelihood, excluding normalization constants.
+- **`_check_is_fit`**: Validates whether the model has been appropriately fit by ensuring model parameters are set. If not, a `NotFittedError` is raised.
+
+
+## The Concrete Class `GLMRecurrent`
+
+The `GLMRecurrent` class is an extension of `GLM`, designed to simulate models with recurrent connections. It inherits the `predict`, `fit`, and `score` methods from `GLM`, and adds a `simulate_recurrent` method for generating spikes from the coupled network.
+
+### Additional Methods
+
+- **`simulate_recurrent`**: This method simulates spike trains, treating the GLM as a recurrent neural network. It utilizes the `observation_model.sample_generator` method to determine the emission probability.
+
+## Contributor Guidelines
+
+### Implementing GLM Subclasses
+
+When crafting a functional (i.e., concrete) GLM class:
+
+- **Must** inherit from `BaseRegressor` or one of its derivatives.
+- **Must** implement the `predict`, `fit`, `score`, and `simulate` methods, either directly or through inheritance.
+- **Should** incorporate an `observation_model` attribute of type `nemos.observation_models.Observations` to specify the link-function, emission probability, and likelihood.
+- **Should** include a `regularizer` attribute of type `nemos.regularizer.Regularizer` to instantiate the solver based on the regularization type.
+- **May** embed additional parameter and input checks if required by the specific GLM subclass.
diff --git a/docs/developers_notes/GLM_scheme.jpg b/docs/developers_notes/GLM_scheme.jpg
new file mode 100644
index 00000000..712a98ba
Binary files /dev/null and b/docs/developers_notes/GLM_scheme.jpg differ
diff --git a/docs/developers_notes/README.md b/docs/developers_notes/README.md
index 9c6b26ac..59c9b4bd 100644
--- a/docs/developers_notes/README.md
+++ b/docs/developers_notes/README.md
@@ -2,6 +2,7 @@
Welcome to the Developer Notes of the `nemos` project. These notes aim to provide detailed technical information about the various modules, classes, and functions that make up this library, as well as guidelines on how to write code that integrates nicely with our package. They are intended to help current and future developers understand the design decisions, structure, and functioning of the library, and to provide guidance on how to modify, extend, and maintain the codebase.
+
## Intended Audience
These notes are primarily intended for the following groups:
diff --git a/docs/examples/README.md b/docs/examples/README.md
index ec619b83..8ad350e2 100644
--- a/docs/examples/README.md
+++ b/docs/examples/README.md
@@ -1,3 +1,3 @@
# Examples
-This will contain tutorials and examples of package usage.
\ No newline at end of file
+A gallery of tutorials on the current `nemos` functionalities.
\ No newline at end of file
diff --git a/docs/examples/plot_1D_basis_function.py b/docs/examples/plot_1D_basis_function.py
index b5d040e2..b6b8bf94 100644
--- a/docs/examples/plot_1D_basis_function.py
+++ b/docs/examples/plot_1D_basis_function.py
@@ -12,8 +12,9 @@
- The order of the spline, which should be an integer greater than 1.
"""
-import numpy as np
import matplotlib.pylab as plt
+import numpy as np
+
import nemos as nmo
# Initialize hyperparameters
diff --git a/docs/examples/plot_ND_basis_function.py b/docs/examples/plot_ND_basis_function.py
index 5752fcca..5f34a679 100644
--- a/docs/examples/plot_ND_basis_function.py
+++ b/docs/examples/plot_ND_basis_function.py
@@ -63,8 +63,9 @@
# $$
# Here, we simply add two basis objects, `a_basis` and `b_basis`, together to define the additive basis.
-import numpy as np
import matplotlib.pyplot as plt
+import numpy as np
+
import nemos as nmo
# Define 1D basis objects
diff --git a/docs/examples/plot_example_convolution.py b/docs/examples/plot_example_convolution.py
index dd2b493e..c40934d6 100644
--- a/docs/examples/plot_example_convolution.py
+++ b/docs/examples/plot_example_convolution.py
@@ -6,9 +6,10 @@
# ## Generate synthetic data
# Generate some simulated spike counts.
-import numpy as np
-import matplotlib.pylab as plt
import matplotlib.patches as patches
+import matplotlib.pylab as plt
+import numpy as np
+
import nemos as nmo
np.random.seed(10)
diff --git a/docs/examples/plot_glm_demo.py b/docs/examples/plot_glm_demo.py
new file mode 100644
index 00000000..93bc6445
--- /dev/null
+++ b/docs/examples/plot_glm_demo.py
@@ -0,0 +1,375 @@
+"""
+# GLM Demo: Toy Model Examples
+
+!!! warning
+ This demonstration is currently in its alpha stage. It presents various regularization techniques on
+    GLMs trained on Gaussian noise stimuli, and a minimal example of fitting and simulating a pair of coupled
+ neurons. More work needs to be done to properly compare the performance of the regularization strategies on
+ realistic simulations and real neural recordings.
+
+## Introduction
+
+In this demo we will work through two toy examples of a Poisson-GLM on synthetic data: a purely feed-forward input model
+and a recurrently coupled model.
+
+In particular, we will learn how to:
+
+- Define & configure a GLM object.
+- Fit the model.
+- Cross-validate the model with `sklearn`.
+- Simulate spike trains.
+
+Before digging into the GLM module, let's first import the packages
+ we are going to use for this tutorial, and generate some synthetic
+ data.
+
+"""
+
+import jax
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.patches import Rectangle
+from sklearn import model_selection
+
+import nemos as nmo
+from nemos import simulation
+
+# enable float64 precision (optional)
+jax.config.update("jax_enable_x64", True)
+
+np.random.seed(111)
+# random design tensor. Shape (n_time_points, n_neurons, n_features).
+X = 0.5*np.random.normal(size=(100, 1, 5))
+
+# log-rates & weights, shape (n_neurons, ) and (n_neurons, n_features) respectively.
+b_true = np.zeros((1, ))
+w_true = np.random.normal(size=(1, 5))
+
+# sparsify weights
+w_true[0, 1:4] = 0.
+
+# generate counts
+rate = jax.numpy.exp(jax.numpy.einsum("ik,tik->ti", w_true, X) + b_true[None, :])
+spikes = np.random.poisson(rate)
+
+# %%
+# ## The Feed-Forward GLM
+#
+# ### Model Definition
+# The class implementing the feed-forward GLM is `nemos.glm.GLM`.
+# In order to define the class, one **must** provide:
+#
+# - **Observation Model**: The observation model for the GLM, e.g. an object of the class of type
+# `nemos.observation_models.Observations`. So far, only the `PoissonObservations`
+# model has been implemented.
+# - **Regularizer**: The desired regularizer, e.g. an object of the `nemos.regularizer.Regularizer` class.
+# Currently, we implemented the un-regularized, Ridge, Lasso, and Group-Lasso regularization.
+#
+# The default for the GLM class is the `PoissonObservations` with a log link function and Ridge regularization.
+# Here is how to define the model.
+
+# default Poisson GLM with Ridge regularization and Poisson observation model.
+model = nmo.glm.GLM()
+
+print("Regularization type: ", type(model.regularizer))
+print("Observation model:", type(model.observation_model))
+
+# %%
+# ### Model Configuration
+# One can inspect the model hyperparameters by calling the `get_params` method.
+
+# get the glm model parameters only
+print("\nGLM model parameters:")
+for key, value in model.get_params(deep=False).items():
+ print(f"\t- {key}: {value}")
+
+# get the glm model parameters, including all the
+# attributes
+print("\nNested parameters:")
+for key, value in model.get_params(deep=True).items():
+ if key in model.get_params(deep=False):
+ continue
+ print(f"\t- {key}: {value}")
+
+# %%
+# These parameters can be configured at initialization and/or
+# set after the model is initialized with the following syntax:
+
+# Poisson observation model with soft-plus NL
+observation_models = nmo.observation_models.PoissonObservations(jax.nn.softplus)
+
+# Ridge regularizer with an LBFGS solver
+regularizer = nmo.regularizer.Ridge(
+ solver_name="LBFGS",
+ regularizer_strength=0.1,
+ solver_kwargs={"tol":10**-10}
+)
+
+# define the GLM
+model = nmo.glm.GLM(
+ observation_model=observation_models,
+ regularizer=regularizer,
+)
+
+print("Regularizer type: ", type(model.regularizer))
+print("Observation model:", type(model.observation_model))
+
+# %%
+# Hyperparameters can be set at any moment via the `set_params` method.
+
+model.set_params(
+ regularizer=nmo.regularizer.Lasso(),
+ observation_model__inverse_link_function=jax.numpy.exp
+)
+
+print("Updated regularizer: ", model.regularizer)
+print("Updated NL: ", model.observation_model.inverse_link_function)
+
+# %%
+# !!! warning
+# Each `Regularizer` has an associated attribute `Regularizer.allowed_solvers`
+# which lists the optimizers that are suited for each optimization problem.
+#     For example, a `Ridge` is differentiable and can be fit with `GradientDescent`,
+#     `BFGS`, etc., while a `Lasso` should use the `ProximalGradient` method instead.
+# If the provided `solver_name` is not listed in the `allowed_solvers` this will raise an
+# exception.
+
+# %%
+# ### Model Fit
+# Fitting the model is as straightforward as calling `model.fit`,
+# providing the design tensor and the population counts.
+# Additionally, one may provide an initial parameter guess.
+# The exact same syntax works for any configuration.
+
+# fit a ridge regression Poisson GLM
+model = nmo.glm.GLM()
+model.set_params(regularizer__regularizer_strength=0.1)
+model.fit(X, spikes)
+
+print("Ridge results")
+print("True weights: ", w_true)
+print("Recovered weights: ", model.coef_)
+
+# %%
+# ## K-fold Cross Validation with `sklearn`
+# Our implementation follows the `scikit-learn` API; this enables us
+# to take advantage of the `scikit-learn` tool-box seamlessly, while at the
+# same time taking advantage of the `jax` GPU acceleration and
+# auto-differentiation in the back-end.
+#
+# Here is an example of how we can perform 5-fold cross-validation via `scikit-learn`.
+#
+# **Ridge**
+
+parameter_grid = {"regularizer__regularizer_strength": np.logspace(-1.5, 1.5, 6)}
+# in practice, you should use more folds than 2, but for the purposes of this
+# demo, 2 is sufficient.
+cls = model_selection.GridSearchCV(model, parameter_grid, cv=2)
+cls.fit(X, spikes)
+
+print("Ridge results ")
+print("Best hyperparameter: ", cls.best_params_)
+print("True weights: ", w_true)
+print("Recovered weights: ", cls.best_estimator_.coef_)
+
+# %%
+# We can compare the Ridge cross-validated results with other regularization schemes.
+#
+# **Lasso**
+
+model.set_params(regularizer=nmo.regularizer.Lasso())
+cls = model_selection.GridSearchCV(model, parameter_grid, cv=2)
+cls.fit(X, spikes)
+
+print("Lasso results ")
+print("Best hyperparameter: ", cls.best_params_)
+print("True weights: ", w_true)
+print("Recovered weights: ", cls.best_estimator_.coef_)
+
+# %%
+# **Group Lasso**
+
+# define groups by masking. Mask size (n_groups, n_features)
+mask = np.zeros((2, 5))
+mask[0, [0, -1]] = 1
+mask[1, 1:-1] = 1
+
+regularizer = nmo.regularizer.GroupLasso("ProximalGradient", mask=mask)
+model.set_params(regularizer=regularizer)
+cls = model_selection.GridSearchCV(model, parameter_grid, cv=2)
+cls.fit(X, spikes)
+
+print("\nGroup Lasso results")
+print("Group mask: :")
+print(mask)
+print("Best hyperparameter: ", cls.best_params_)
+print("True weights: ", w_true)
+print("Recovered weights: ", cls.best_estimator_.coef_)
+
+# %%
+# ## Simulate Spikes
+# We can generate spikes in response to a feedforward stimulus
+# through the `model.simulate` method.
+
+# here we are creating a new data input, of 20 timepoints (arbitrary)
+# with the same number of neurons and features (mandatory)
+Xnew = np.random.normal(size=(20, ) + X.shape[1:])
+# generate a random key given a seed
+random_key = jax.random.PRNGKey(123)
+spikes, rates = model.simulate(random_key, Xnew)
+
+plt.figure()
+plt.eventplot(np.where(spikes)[0])
+
+
+# %%
+# ## Recurrently Coupled GLM
+# Defining a recurrent model follows the same syntax. In this example
+# we will simulate two coupled neurons, and we will inject a transient
+# input driving the rate of one of the neurons.
+
+
+# Neural population parameters
+n_neurons = 2
+coupling_filter_duration = 100
+
+# %%
+# We can now define the coupling filters that we will use to simulate
+# the pairwise interactions between the neurons. We will model the
+# filters as a difference of two Gamma probability density functions.
+# The negative component will capture inhibitory effects such as the
+# refractory period of a neuron, while the positive component will
+# describe excitation.
+
+np.random.seed(101)
+
+# Gamma parameters for the inhibitory component of the filter
+inhib_a = 1
+inhib_b = 1
+
+# Gamma parameters for the excitatory component of the filter
+excit_a = np.random.uniform(1.1, 5, size=(n_neurons, n_neurons))
+excit_b = np.random.uniform(1.1, 5, size=(n_neurons, n_neurons))
+
+# define the 2x2 coupling filters with the difference_of_gammas helper
+coupling_filter_bank = np.zeros((coupling_filter_duration, n_neurons, n_neurons))
+for unit_i in range(n_neurons):
+ for unit_j in range(n_neurons):
+ coupling_filter_bank[:, unit_i, unit_j] = nmo.simulation.difference_of_gammas(
+ coupling_filter_duration,
+ inhib_a=inhib_a,
+ excit_a=excit_a[unit_i, unit_j],
+ inhib_b=inhib_b,
+ excit_b=excit_b[unit_i, unit_j],
+ )
+
+# shrink the filters for simulation stability
+coupling_filter_bank *= 0.8
+
+# %%
+# If we represent our filters in terms of basis functions, we can simulate our network by
+# directly calling the `simulate_recurrent` method of the `nmo.glm.GLMRecurrent` class.
+
+# define a basis function
+n_basis_funcs = 20
+basis = nmo.basis.RaisedCosineBasisLog(n_basis_funcs)
+
+# approximate the coupling filters in terms of the basis function
+_, coupling_basis = basis.evaluate_on_grid(coupling_filter_bank.shape[0])
+coupling_coeff = simulation.regress_filter(coupling_filter_bank, coupling_basis)
+intercept = -4 * np.ones(n_neurons)
+
+# %%
+# We can check that our approximation worked by plotting the original filters
+# and the basis expansion
+
+# plot coupling functions
+n_basis_coupling = coupling_basis.shape[1]
+fig, axs = plt.subplots(n_neurons, n_neurons)
+plt.suptitle("Coupling filters")
+for unit_i in range(n_neurons):
+ for unit_j in range(n_neurons):
+ axs[unit_i, unit_j].set_title(f"unit {unit_j} -> unit {unit_i}")
+ coeff = coupling_coeff[unit_i, unit_j]
+ axs[unit_i, unit_j].plot(coupling_filter_bank[:, unit_i, unit_j], label="gamma difference")
+ axs[unit_i, unit_j].plot(np.dot(coupling_basis, coeff), ls="--", color="k", label="basis function")
+axs[0, 0].legend()
+plt.tight_layout()
+
+# %%
+# Define a square stimulus current for the first neuron, and no stimulus for
+# the second neuron
+
+# define the square current parameters
+simulation_duration = 1000
+stimulus_onset = 200
+stimulus_offset = 500
+stimulus_intensity = 1.5
+
+# create the input tensor of shape (n_samples, n_neurons, n_dimension_stimuli)
+feedforward_input = np.zeros((simulation_duration, n_neurons, 1))
+# inject square input to the first neuron only
+feedforward_input[stimulus_onset: stimulus_offset, 0] = stimulus_intensity
+
+# plot the input
+fig, axs = plt.subplots(1,2)
+plt.suptitle("Feedforward inputs")
+axs[0].set_title("Input to neuron 0")
+axs[0].plot(feedforward_input[:, 0])
+
+axs[1].set_title("Input to neuron 1")
+axs[1].plot(feedforward_input[:, 1])
+axs[1].set_ylim(axs[0].get_ylim())
+
+
+# the input for the simulation will be the dot product
+# of input_coeff with the feedforward_input
+input_coeff = np.ones((n_neurons, 1))
+
+# stack the coefficients in a single matrix
+basis_coeff = np.hstack((coupling_coeff.reshape(n_neurons, -1), input_coeff))
+
+# initialize the spikes for the recurrent simulation
+init_spikes = np.zeros((coupling_filter_duration, n_neurons))
+
+# %%
+# We can now simulate spikes by calling the `simulate_recurrent` method.
+
+model = nmo.glm.GLMRecurrent()
+model.coef_ = jax.numpy.asarray(basis_coeff)
+model.intercept_ = jax.numpy.asarray(intercept)
+
+
+# call simulate, with both the recurrent coupling
+# and the input
+spikes, rates = model.simulate_recurrent(
+ jax.random.PRNGKey(123),
+ feedforward_input=feedforward_input,
+ coupling_basis_matrix=coupling_basis,
+ init_y=init_spikes
+)
+
+# %%
+# And finally plot the results for both neurons.
+
+# mkdocs_gallery_thumbnail_number = 4
+plt.figure()
+ax = plt.subplot(111)
+
+ax.spines['top'].set_visible(False)
+ax.spines['right'].set_visible(False)
+
+patch = Rectangle((200, -0.011), 300, 0.15, alpha=0.2, color="grey")
+
+p0, = plt.plot(rates[:, 0])
+p1, = plt.plot(rates[:, 1])
+
+plt.vlines(np.where(spikes[:, 0])[0], 0.00, 0.01, color=p0.get_color(), label="rate neuron 0")
+plt.vlines(np.where(spikes[:, 1])[0], -0.01, 0.00, color=p1.get_color(), label="rate neuron 1")
+plt.plot(np.exp(basis_coeff[0, -1] * feedforward_input[:, 0, 0] + intercept[0]), color='k', lw=0.8, label="stimulus")
+ax.add_patch(patch)
+plt.ylim(-0.011, .13)
+plt.ylabel("count/bin")
+plt.legend()
+
+
diff --git a/docs/javascripts/katex.js b/docs/javascripts/katex.js
index f7fd7047..250961c8 100644
--- a/docs/javascripts/katex.js
+++ b/docs/javascripts/katex.js
@@ -4,7 +4,8 @@ document$.subscribe(({ body }) => {
{ left: "$$", right: "$$", display: true },
{ left: "$", right: "$", display: false },
{ left: "\\(", right: "\\)", display: false },
- { left: "\\[", right: "\\]", display: true }
+ { left: "\\[", right: "\\]", display: true },
+ { left: "\\begin{aligned}", right: "\\end{aligned}", display: true }
],
})
})
diff --git a/mkdocs.yml b/mkdocs.yml
index 016666d8..abf651ff 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -7,6 +7,16 @@ theme:
primary: 'light blue' # The primary color palette for the theme
features:
- navigation.tabs # Enable navigation tabs feature for the theme
+ markdown_extensions:
+ - md_in_html
+ - admonition
+ - tables
+
+# If footnotes is defined under theme, it doesn't work.
+# If md_in_html is defined outside theme, it also results in
+# an error when building the docs.
+markdown_extensions:
+ - footnotes
plugins:
- search
@@ -26,6 +36,7 @@ plugins:
docstring_style: numpy
show_source: true
members_order: source
+ inherited_members: true
extra_javascript:
- javascripts/katex.js
@@ -40,4 +51,3 @@ nav:
- Tutorials: generated/gallery # Link to the generated gallery as Tutorials
- For Developers: developers_notes/ # Link to the developers notes
- Code References: reference/ # Link to the reference/ directory
-
diff --git a/pyproject.toml b/pyproject.toml
index 5baa7857..778b0916 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,7 +11,7 @@ authors = [
{name = "Guillaume Vejo", email = "gviejo@flatironinstitute.org"},
{name = "Alex Williams", email = "alex.h.williams@nyu.edu"}
]
-description = "Toolbox for basic Generalized Linear Models (GLMs) for neural data analysis"
+description = "NEural MOdelS, a statistical modeling framework for neuroscience."
readme = "README.md"
requires-python = ">=3.8"
keywords = ["neuroscience", "Poisson-GLM"]
@@ -31,7 +31,6 @@ dependencies = [
'jaxopt>=0.6', # Optimization library built on JAX
'matplotlib>=3.7', # Plotting library
'numpy>1.20', # Numerical computing library
- 'scikit-learn>=1.2', # Machine learning library
'scipy>=1.10', # Scientific computing library
'typing_extensions>=4.6' # Typing extensions for Python
]
@@ -54,15 +53,18 @@ dev = [
"flake8", # Code linter
"coverage", # Test coverage measurement
"pytest-cov", # Test coverage plugin for pytest
+ "statsmodels", # Used to compare model pseudo-r2 in testing
+ "scikit-learn", # Testing compatibility with CV & pipelines
]
docs = [
"mkdocs", # Documentation generator
"mkdocstrings[python]", # Python-specific plugin for mkdocs
"mkdocs-section-index", # Plugin for generating a section index in mkdocs
"mkdocs-gen-files", # Plugin for generating additional files in mkdocs
- "mkdocs-literate-nav", # Plugin for literate-style navigation in mkdocs
+ "mkdocs-literate-nav>=0.6.1", # Plugin for literate-style navigation in mkdocs
"mkdocs-gallery", # Plugin for adding image galleries to mkdocs
- "mkdocs-material"
+ "mkdocs-material",
+ "mkdocs-autorefs>=0.5"
]
@@ -96,8 +98,8 @@ profile = "black"
# Configure pytest
[tool.pytest.ini_options]
-addopts = "--cov=nemos" # Additional options to pass to pytest, enabling coverage for the 'nemos' package
testpaths = ["tests"] # Specify the directory where test files are located
+addopts = "--cov=src"
[tool.coverage.report]
exclude_lines = [
diff --git a/src/nemos/__init__.py b/src/nemos/__init__.py
index 82a888b7..8938c85c 100644
--- a/src/nemos/__init__.py
+++ b/src/nemos/__init__.py
@@ -1,3 +1,12 @@
#!/usr/bin/env python3
-from . import basis, glm, sample_points, utils
+from . import (
+ basis,
+ exceptions,
+ glm,
+ observation_models,
+ regularizer,
+ sample_points,
+ simulation,
+ utils,
+)
diff --git a/src/nemos/base_class.py b/src/nemos/base_class.py
new file mode 100644
index 00000000..4ee3ada6
--- /dev/null
+++ b/src/nemos/base_class.py
@@ -0,0 +1,450 @@
+"""Abstract class for estimators."""
+
+import abc
+import inspect
+import warnings
+from collections import defaultdict
+from typing import Any, Optional, Tuple, Union
+
+import jax
+import jax.numpy as jnp
+from numpy.typing import ArrayLike, NDArray
+
+from .utils import check_invalid_entry
+
+
+class Base:
+ """Base class for nemos estimators.
+
+ A base class for estimators with utilities for getting and setting parameters,
+ and for interacting with specific devices like CPU, GPU, and TPU.
+
+ This class provides utilities for:
+ - Getting and setting parameters using introspection.
+ - Sending arrays to target devices (CPU, GPU, TPU).
+
+ Parameters
+ ----------
+ **kwargs : dict
+ Arbitrary keyword arguments.
+
+ Notes
+ -----
+ The class provides helper methods mimicking scikit-learn's get_params and set_params.
+ Additionally, it has methods for selecting target devices and sending arrays to them.
+ """
+
+ def get_params(self, deep=True):
+ """
+ From scikit-learn, get parameters by inspecting init.
+
+ Parameters
+ ----------
+        deep :
+            If True, return the parameters of this estimator and
+            of the contained sub-objects that are estimators.
+
+ Returns
+ -------
+ out:
+ A dictionary containing the parameters. Key is the parameter
+ name, value is the parameter value.
+ """
+ out = dict()
+ for key in self._get_param_names():
+ value = getattr(self, key)
+ if deep and hasattr(value, "get_params") and not isinstance(value, type):
+ deep_items = value.get_params().items()
+ out.update((key + "__" + k, val) for k, val in deep_items)
+ out[key] = value
+ return out
+
+ def set_params(self, **params: Any):
+ """Set the parameters of this estimator.
+
+ The method works on simple estimators as well as on nested objects
+ (such as :class:`~sklearn.pipeline.Pipeline`). The latter have
+        parameters of the form ``<component>__<parameter>`` so that it's
+ possible to update each component of a nested object.
+
+ Parameters
+ ----------
+ **params : dict
+ Estimator parameters.
+
+ Returns
+ -------
+ self : estimator instance
+ Estimator instance.
+ """
+ if not params:
+ # Simple optimization to gain speed (inspect is slow)
+ return self
+ valid_params = self.get_params(deep=True)
+ nested_params: defaultdict = defaultdict(dict) # grouped by prefix
+ for key, value in params.items():
+ key, delim, sub_key = key.partition("__")
+ if key not in valid_params:
+ local_valid_params = self._get_param_names()
+ raise ValueError(
+ f"Invalid parameter {key!r} for estimator {self}. "
+ f"Valid parameters are: {local_valid_params!r}."
+ )
+
+ if delim:
+ nested_params[key][sub_key] = value
+ else:
+ setattr(self, key, value)
+ valid_params[key] = value
+
+ for key, sub_params in nested_params.items():
+ # TODO(1.4): remove specific handling of "base_estimator".
+ # The "base_estimator" key is special. It was deprecated and
+ # renamed to "estimator" for several estimators. This means we
+ # need to translate it here and set sub-parameters on "estimator",
+ # but only if the user did not explicitly set a value for
+ # "base_estimator".
+ if (
+ key == "base_estimator"
+ and valid_params[key] == "deprecated"
+ and self.__module__.startswith("sklearn.")
+ ):
+ warnings.warn(
+ (
+ f"Parameter 'base_estimator' of {self.__class__.__name__} is"
+ " deprecated in favor of 'estimator'. See"
+ f" {self.__class__.__name__}'s docstring for more details."
+ ),
+ FutureWarning,
+ stacklevel=2,
+ )
+ key = "estimator"
+ valid_params[key].set_params(**sub_params)
+
+ return self
+
+ @classmethod
+ def _get_param_names(cls):
+ """Get parameter names for the estimator."""
+ # fetch the constructor or the original constructor before
+ # deprecation wrapping if any
+ init = getattr(cls.__init__, "deprecated_original", cls.__init__)
+ if init is object.__init__:
+ # No explicit constructor to introspect
+ return []
+
+ # introspect the constructor arguments to find the model parameters
+ # to represent
+ init_signature = inspect.signature(init)
+ # Consider the constructor parameters excluding 'self'
+ parameters = [
+ p
+ for p in init_signature.parameters.values()
+ if p.name != "self" and p.kind != p.VAR_KEYWORD
+ ]
+ for p in parameters:
+ if p.kind == p.VAR_POSITIONAL:
+ raise RuntimeError(
+ "GLM estimators should always "
+ "specify their parameters in the signature"
+ " of their __init__ (no varargs)."
+ " %s with constructor %s doesn't "
+ " follow this convention." % (cls, init_signature)
+ )
+
+ # Consider the constructor parameters excluding 'self'
+ parameters = [
+ p.name for p in init_signature.parameters.values() if p.name != "self"
+ ]
+
+ # remove kwargs
+ if "kwargs" in parameters:
+ parameters.remove("kwargs")
+ # Extract and sort argument names excluding 'self'
+ return sorted(parameters)
+
+
+class BaseRegressor(Base, abc.ABC):
+ """Abstract base class for GLM regression models.
+
+ This class encapsulates the common functionality for Generalized Linear Models (GLM)
+ regression models. It provides an abstraction for fitting the model, making predictions,
+ scoring the model, simulating responses, and preprocessing data. Concrete classes
+ are expected to provide specific implementations of the abstract methods defined here.
+
+ See Also
+ --------
+ Concrete models:
+
+ - [`GLM`](../glm/#nemos.glm.GLM): A feed-forward GLM implementation.
+ - [`GLMRecurrent`](../glm/#nemos.glm.GLMRecurrent): A recurrent GLM implementation.
+ """
+
+ @abc.abstractmethod
+ def fit(self, X: Union[NDArray, jnp.ndarray], y: Union[NDArray, jnp.ndarray]):
+ """Fit the model to neural activity."""
+ pass
+
+ @abc.abstractmethod
+ def predict(self, X: Union[NDArray, jnp.ndarray]) -> jnp.ndarray:
+ """Predict rates based on fit parameters."""
+ pass
+
+ @abc.abstractmethod
+ def score(
+ self,
+ X: Union[NDArray, jnp.ndarray],
+ y: Union[NDArray, jnp.ndarray],
+ # may include score_type or other additional model dependent kwargs
+ **kwargs,
+ ) -> jnp.ndarray:
+ """Score the predicted firing rates (based on fit) to the target neural activity."""
+ pass
+
+ @abc.abstractmethod
+ def simulate(
+ self,
+ random_key: jax.random.PRNGKeyArray,
+ feed_forward_input: Union[NDArray, jnp.ndarray],
+ ):
+ """Simulate neural activity in response to a feed-forward input and recurrent activity."""
+ pass
+
+ @staticmethod
+ def _check_and_convert_params(
+ params: Tuple[ArrayLike, ArrayLike], data_type: Optional[jnp.dtype] = None
+ ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+ """
+ Validate the dimensions and consistency of parameters and data.
+
+ This function checks the consistency of shapes and dimensions for model
+ parameters.
+ It ensures that the parameters and data are compatible for the model.
+
+ """
+ try:
+ params = tuple(jnp.asarray(par, dtype=data_type) for par in params)
+ except (ValueError, TypeError):
+ raise TypeError(
+ "Initial parameters must be array-like of array-like objects "
+ "with numeric data-type!"
+ )
+
+ if len(params) != 2:
+ raise ValueError("Params needs to be array-like of length two.")
+
+ if params[0].ndim != 2:
+ raise ValueError(
+ "params[0] must be of shape (n_neurons, n_features), but"
+ f"params[0] has {params[0].ndim} dimensions!"
+ )
+ if params[1].ndim != 1:
+ raise ValueError(
+ "params[1] must be of shape (n_neurons,) but "
+ f"params[1] has {params[1].ndim} dimensions!"
+ )
+ return params
+
+ @staticmethod
+ def _check_input_dimensionality(
+ X: Optional[jnp.ndarray] = None, y: Optional[jnp.ndarray] = None
+ ):
+ if not (y is None):
+ if y.ndim != 2:
+ raise ValueError(
+ "y must be two-dimensional, with shape (n_timebins, n_neurons)"
+ )
+ if not (X is None):
+ if X.ndim != 3:
+ raise ValueError(
+ "X must be three-dimensional, with shape (n_timebins, n_neurons, n_features)"
+ )
+
+ @staticmethod
+ def _check_input_and_params_consistency(
+ params: Tuple[jnp.ndarray, jnp.ndarray],
+ X: Optional[jnp.ndarray] = None,
+ y: Optional[jnp.ndarray] = None,
+ ):
+ """
+ Validate the number of neurons in model parameters and input arguments.
+
+ Raises
+ ------
+ ValueError
+ - if the number of neurons is inconsistent across the model parameters (`params`) and
+ any additional inputs (`X` or `y` when provided).
+ - if the number of features is inconsistent between params[1] and X (when provided).
+
+ """
+ n_neurons = params[0].shape[0]
+ if n_neurons != params[1].shape[0]:
+ raise ValueError(
+ "Model parameters have inconsistent shapes. "
+ "Spike basis coefficients must be of shape (n_neurons, n_features), and "
+ "bias terms must be of shape (n_neurons,) but n_neurons doesn't look the same in both! "
+ f"Coefficients n_neurons: {params[0].shape[0]}, bias n_neurons: {params[1].shape[0]}"
+ )
+
+ if y is not None:
+ if y.shape[1] != n_neurons:
+ raise ValueError(
+ "The number of neurons in the model parameters and in the inputs"
+ "must match."
+ f"parameters has n_neurons: {n_neurons}, "
+ f"the input provided has n_neurons: {y.shape[1]}"
+ )
+
+ if X is not None:
+ if X.shape[1] != n_neurons:
+ raise ValueError(
+ "The number of neurons in the model parameters and in the inputs"
+ "must match."
+ f"parameters has n_neurons: {n_neurons}, "
+ f"the input provided has n_neurons: {X.shape[1]}"
+ )
+ if params[0].shape[1] != X.shape[2]:
+ raise ValueError(
+ "Inconsistent number of features. "
+ f"spike basis coefficients has {params[0].shape[1]} features, "
+ f"X has {X.shape[2]} features instead!"
+ )
+
+ @staticmethod
+ def _check_input_n_timepoints(X: jnp.ndarray, y: jnp.ndarray):
+ if X.shape[0] != y.shape[0]:
+ raise ValueError(
+ "The number of time-points in X and y must agree. "
+ f"X has {X.shape[0]} time-points, "
+ f"y has {y.shape[0]} instead!"
+ )
+
+ def _preprocess_fit(
+ self,
+ X: Union[NDArray, jnp.ndarray],
+ y: Union[NDArray, jnp.ndarray],
+ init_params: Optional[Tuple[ArrayLike, ArrayLike]] = None,
+ ) -> Tuple[jnp.ndarray, jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
+ """Preprocess input data and initial parameters for the fit method.
+
+ This method carries out the following preprocessing steps:
+
+ - Convert to jax.numpy.ndarray
+
+ - Check the dimensionality of the inputs.
+
+ - Check for any NaNs or Infs in the inputs.
+
+ - If `init_params` is not provided, initialize it with default values.
+
+ - Validate the consistency of input dimensions with the initial parameters.
+
+ Parameters
+ ----------
+ X :
+ Input data, expected to be of shape (n_timebins, n_neurons, n_features).
+ y :
+ Target values, expected to be of shape (n_timebins, n_neurons).
+ init_params :
+ Initial parameters for the model. If None, they are initialized with default values.
+
+ Returns
+ -------
+ X :
+ Preprocessed input data `X` converted to jnp.ndarray.
+ y :
+ Target values `y` converted to jnp.ndarray.
+ init_params :
+ Initialized parameters converted to jnp.ndarray.
+
+ Raises
+ ------
+ ValueError
+ If there are inconsistencies in the input shapes or if NaNs or Infs are detected.
+ """
+ X, y = jnp.asarray(X, dtype=float), jnp.asarray(y, dtype=float)
+
+ # check input dimensionality
+ self._check_input_dimensionality(X, y)
+ self._check_input_n_timepoints(X, y)
+
+ check_invalid_entry(X, "X")
+ check_invalid_entry(y, "y")
+
+ _, n_neurons = y.shape
+ n_features = X.shape[2]
+
+ # Initialize parameters
+ if init_params is None:
+ # Ws, spike basis coeffs
+ init_params = (
+ jnp.zeros((n_neurons, n_features)),
+ # bs, bias terms
+ jnp.log(jnp.mean(y, axis=0)),
+ )
+ else:
+ # check parameter length, shape and dimensionality, convert to jnp.ndarray.
+ init_params = self._check_and_convert_params(init_params)
+
+ # check that the inputs and the parameters have consistent sizes
+ self._check_input_and_params_consistency(init_params, X=X, y=y)
+
+ return X, y, init_params
+
+ def _preprocess_simulate(
+ self,
+ feedforward_input: Union[NDArray, jnp.ndarray],
+ params_feedforward: Tuple[jnp.ndarray, jnp.ndarray],
+ init_y: Optional[Union[NDArray, jnp.ndarray]] = None,
+ params_recurrent: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None,
+ ) -> Tuple[jnp.ndarray, ...]:
+ """
+ Preprocess the input data and parameters for simulation.
+
+ This method handles the conversion of the input data to `jnp.ndarray`, checks the
+ input's dimensionality, and ensures the input's consistency with the provided parameters.
+ It also verifies that the feedforward input does not have any invalid entries (NaNs or Infs).
+
+ Parameters
+ ----------
+ feedforward_input :
+ Input data for the feedforward process. Expected shape: (n_time_bins, n_neurons, n_basis_input).
+ params_feedforward :
+ Parameters corresponding to the feedforward input. Expected shape: (n_neurons, n_basis_input).
+ init_y :
+ Initial values for the feedback process. If provided, its dimensionality and consistency
+ with `params_recurrent` will be checked. Expected shape if provided: (window_size, n_neurons).
+ params_recurrent :
+ Parameters corresponding to the feedback input (`init_y`). Required if `init_y` is provided.
+ Expected shape if provided: (n_neurons, n_basis_coupling * n_neurons).
+
+ Returns
+ -------
+ :
+ Preprocessed input data, optionally with the initial values for feedback if provided.
+
+ Raises
+ ------
+ ValueError
+ If the feedforward_input contains NaNs or Infs.
+ If the dimensionality or consistency checks fail for the provided data and parameters.
+ """
+ feedforward_input = jnp.asarray(feedforward_input, dtype=float)
+ self._check_input_dimensionality(X=feedforward_input)
+ self._check_input_and_params_consistency(
+ params_feedforward, X=feedforward_input
+ )
+
+ check_invalid_entry(feedforward_input, "feedforward_input")
+
+ # Ensure that both or neither of `init_y` and `params_recurrent` are provided
+ if (init_y is None) != (params_recurrent is None):
+ raise ValueError(
+ "Both `init_y` and `params_recurrent` should be provided, or neither should be."
+ )
+ # If both are provided, perform checks and conversions
+ elif init_y is not None and params_recurrent is not None:
+ init_y = jnp.asarray(init_y, dtype=float)
+ self._check_input_dimensionality(y=init_y)
+ self._check_input_and_params_consistency(params_recurrent, y=init_y)
+ return feedforward_input, init_y
+
+ return (feedforward_input,)
diff --git a/src/nemos/basis.py b/src/nemos/basis.py
index a837f71b..428ac417 100644
--- a/src/nemos/basis.py
+++ b/src/nemos/basis.py
@@ -63,7 +63,7 @@ def evaluate(self, *xi: NDArray) -> NDArray:
pass
@staticmethod
- def _get_samples(*n_samples: int) -> Generator[NDArray, ...]:
+ def _get_samples(*n_samples: int) -> Generator[NDArray, None, None]:
"""Get equi-spaced samples for all the input dimensions.
This will be used to evaluate the basis on a grid of
@@ -517,12 +517,11 @@ class MSplineBasis(SplineBasis):
at each interior knot. The higher this number, the smoother the basis
representation will be.
-
References
----------
- .. [1] Ramsay, J. O. (1988). Monotone regression splines in action.
- Statistical science, 3(4), 425-441.
-
+ [^1]:
+ Ramsay, J. O. (1988). Monotone regression splines in action.
+ Statistical science, 3(4), 425-441.
"""
def __init__(self, n_basis_funcs: int, order: int = 2) -> None:
@@ -597,8 +596,9 @@ class BSplineBasis(SplineBasis):
References
----------
- ..[2] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques.
- Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5
+ [^2]:
+ Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques.
+ Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5
"""
@@ -847,10 +847,11 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
class RaisedCosineBasisLinear(RaisedCosineBasis):
- """Linearly-spaced raised cosine basis functions used by Pillow et al. [2]_.
+ """Linearly-spaced raised cosine basis functions used by Pillow et al.
These are "cosine bumps" that uniformly tile the space.
+
Parameters
----------
n_basis_funcs
@@ -858,10 +859,11 @@ class RaisedCosineBasisLinear(RaisedCosineBasis):
References
----------
- .. [2] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J.,
- C. E. (2005). Prediction and decoding of retinal ganglion cell responses
- with a probabilistic spiking model. Journal of Neuroscience, 25(47),
- 11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005
+ [^3]:
+ Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J.,
+ C. E. (2005). Prediction and decoding of retinal ganglion cell responses
+ with a probabilistic spiking model. Journal of Neuroscience, 25(47),
+ 11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005
"""
diff --git a/src/nemos/exceptions.py b/src/nemos/exceptions.py
new file mode 100644
index 00000000..bcb44de7
--- /dev/null
+++ b/src/nemos/exceptions.py
@@ -0,0 +1,24 @@
+"""Model specific exceptions."""
+
+
+class NotFittedError(ValueError, AttributeError):
+ """Exception class to raise if estimator is used before fitting.
+
+ This class inherits from both ValueError and AttributeError to help with
+ exception handling and backward compatibility.
+
+ Examples
+ --------
+ >>> from nemos.glm import GLM
+ >>> from nemos.exceptions import NotFittedError
+ >>> try:
+ ... GLM().predict([[[1, 2], [2, 3], [3, 4]]])
+ ... except NotFittedError as e:
+ ... print(repr(e))
+ NotFittedError("This GLM instance is not fitted yet. Call 'fit' with appropriate arguments.")
+ """
diff --git a/src/nemos/glm.py b/src/nemos/glm.py
index d2614ecd..6e128afa 100644
--- a/src/nemos/glm.py
+++ b/src/nemos/glm.py
@@ -1,274 +1,247 @@
-"""GLM core module
-"""
-import inspect
-import warnings
-from typing import Callable, Optional, Tuple
+"""GLM core module."""
+from typing import Literal, Optional, Tuple, Union
import jax
import jax.numpy as jnp
-import jaxopt
from numpy.typing import NDArray
-from sklearn.exceptions import NotFittedError
-from .utils import convolve_1d_trials
+from . import observation_models as obs
+from . import regularizer as reg
+from . import utils
+from .base_class import BaseRegressor
+from .exceptions import NotFittedError
-class GLM:
- """Generalized Linear Model for neural responses.
+class GLM(BaseRegressor):
+ """
+ Generalized Linear Model (GLM) for neural activity data.
- No stimulus / external variables, only connections to other neurons.
+ This GLM implementation allows users to model neural activity based on a combination of exogenous inputs
+ (like convolved currents or light intensities) and a choice of observation model. It is suitable for scenarios where
+ the relationship between predictors and the response variable might be non-linear, and the residuals
+ don't follow a normal distribution.
Parameters
----------
- spike_basis_matrix : (n_basis_funcs, window_size)
- Matrix of basis functions to use for this GLM. Most likely the output
- of ``Basis.gen_basis_funcs()``
- solver_name
- Name of the solver to use when fitting the GLM. Must be an attribute of
- ``jaxopt``.
- solver_kwargs
- Dictionary of keyword arguments to pass to the solver during its
- initialization.
- inverse_link_function
- Function to transform outputs of convolution with basis to firing rate.
- Must accept any number as input and return all non-negative values.
+ observation_model :
+ Observation model to use. The model describes the distribution of the neural activity.
+ Default is the Poisson model.
+ regularizer :
+ Regularization to use for model optimization. Defines the regularization scheme, the optimization algorithm,
+ and related parameters.
+ Default is Ridge regression with gradient descent.
Attributes
----------
- solver
- jaxopt solver, set during ``fit()``
- solver_state
- state of the solver, set during ``fit()``
- spike_basis_coeff_ : jnp.ndarray, (n_neurons, n_basis_funcs, n_neurons)
- Solutions for the spike basis coefficients, set during ``fit()``
- baseline_log_fr : jnp.ndarray, (n_neurons,)
- Solutions for bias terms, set during ``fit()``
-
+ intercept_ :
+ Model baseline (bias) terms on the linked firing-rate scale, e.g. if the link is the logarithm, the baseline
+ firing rate will be `jnp.exp(model.intercept_)`.
+ coef_ :
+ Basis coefficients for the model.
+ solver_state :
+ State of the solver after fitting. May include details like optimization error.
+
+ Raises
+ ------
+ TypeError
+ If the provided `regularizer` or `observation_model` is not valid.
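+
+ Examples
+ --------
+ >>> # A minimal end-to-end sketch; the arrays below are placeholders, not real data.
+ >>> import jax.numpy as jnp
+ >>> from nemos.glm import GLM
+ >>> X = jnp.ones((100, 2, 5)) # (n_time_bins, n_neurons, n_features)
+ >>> y = jnp.ones((100, 2)) # (n_time_bins, n_neurons)
+ >>> model = GLM()
+ >>> model.fit(X, y)
+ >>> rates = model.predict(X) # (n_time_bins, n_neurons)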
"""
def __init__(
self,
- spike_basis_matrix: NDArray,
- solver_name: str = "GradientDescent",
- solver_kwargs: dict = dict(),
- inverse_link_function: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.softplus,
- ):
- # (n_basis_funcs, window_size)
- self.spike_basis_matrix = spike_basis_matrix
- self.solver_name = solver_name
- try:
- solver_args = inspect.getfullargspec(getattr(jaxopt, solver_name)).args
- except AttributeError:
- raise AttributeError(
- f"module jaxopt has no attribute {solver_name}, pick a different solver!"
- )
- for k in solver_kwargs.keys():
- if k not in solver_args:
- raise NameError(
- f"kwarg {k} in solver_kwargs is not a kwarg for jaxopt.{solver_name}!"
- )
- self.solver_kwargs = solver_kwargs
- self.inverse_link_function = inverse_link_function
-
- def fit(
- self,
- spike_data: NDArray,
- init_params: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None,
+ observation_model: obs.Observations = obs.PoissonObservations(),
+ regularizer: reg.Regularizer = reg.Ridge("GradientDescent"),
):
- """Fit GLM to spiking data.
+ super().__init__()
- Following scikit-learn API, the solutions are stored as attributes
- ``spike_basis_coeff_`` and ``baseline_log_fr``.
+ self.observation_model = observation_model
+ self.regularizer = regularizer
- Parameters
- ----------
- spike_data : (n_neurons, n_timebins)
- Spike counts arranged in a matrix.
- init_params : ((n_neurons, n_basis_funcs, n_neurons), (n_neurons,))
- Initial values for the spike basis coefficients and bias terms. If
- None, we initialize with zeros.
-
- Raises
- ------
- ValueError
- If spike_data is not two-dimensional.
- ValueError
- If shapes of init_params are not correct.
- ValueError
- If solver returns at least one NaN parameter, which means it found
- an invalid solution. Try tuning optimization hyperparameters.
-
- """
- if spike_data.ndim != 2:
- raise ValueError(
- "spike_data must be two-dimensional, with shape (n_neurons, n_timebins)"
- )
+ # initialize to None fit output
+ self.intercept_ = None
+ self.coef_ = None
+ self.solver_state = None
- n_neurons, _ = spike_data.shape
- n_basis_funcs, window_size = self.spike_basis_matrix.shape
-
- # Convolve spikes with basis functions. We drop the last sample, as
- # those are the features that could be used to predict spikes in the
- # next time bin
- X = jnp.transpose(
- convolve_1d_trials(self.spike_basis_matrix.T, [spike_data.T])[0], (1, 2, 0)
- )[:, :, :-1]
-
- # Initialize parameters
- if init_params is None:
- # Ws, spike basis coeffs
- init_params = (
- jnp.zeros((n_neurons, n_basis_funcs, n_neurons)),
- # bs, bias terms
- jnp.zeros(n_neurons),
- )
+ @property
+ def regularizer(self):
+ """Getter for the regularizer attribute."""
+ return self._regularizer
- if init_params[0].ndim != 3:
- raise ValueError(
- "spike basis coefficients must be of shape (n_neurons, n_basis_funcs, n_neurons), but"
- f" init_params[0] has {init_params[0].ndim} dimensions!"
- )
- if init_params[0].shape[0] != init_params[0].shape[-1]:
- raise ValueError(
- "spike basis coefficients must be of shape (n_neurons, n_basis_funcs, n_neurons), but"
- f" init_params[0] has shape {init_params[0].shape}!"
- )
- if init_params[1].ndim != 1:
- raise ValueError(
- "bias terms must be of shape (n_neurons,) but init_params[0] have"
- f"{init_params[1].ndim} dimensions!"
- )
- if init_params[0].shape[0] != init_params[1].shape[0]:
- raise ValueError(
- "spike basis coefficients must be of shape (n_neurons, n_basis_funcs, n_neurons), and"
- "bias terms must be of shape (n_neurons,) but n_neurons doesn't look the same in both!"
- f"init_params[0]: {init_params[0].shape[0]}, init_params[1]: {init_params[1].shape[0]}"
+ @regularizer.setter
+ def regularizer(self, regularizer: reg.Regularizer):
+ """Setter for the regularizer attribute."""
+ if not hasattr(regularizer, "instantiate_solver"):
+ raise AttributeError(
+ "The provided `solver` doesn't implement the `instantiate_solver` method."
)
- if init_params[0].shape[0] != spike_data.shape[0]:
- raise ValueError(
- "spike basis coefficients must be of shape (n_neurons, n_basis_funcs, n_neurons), and"
- "spike_data must be of shape (n_neurons, n_timebins) but n_neurons doesn't look the same in both!"
- f"init_params[0]: {init_params[0].shape[0]}, spike_data: {spike_data.shape[0]}"
+ # test solver instantiation on the GLM loss
+ try:
+ regularizer.instantiate_solver(self._predict_and_compute_loss)
+ except Exception:
+ raise TypeError(
+ "The provided `solver` cannot be instantiated on "
+ "the GLM log-likelihood."
)
-
- def loss(params, X, y):
- predicted_firing_rates = self._predict(params, X)
- return self._score(predicted_firing_rates, y)
-
- # Run optimization
- solver = getattr(jaxopt, self.solver_name)(fun=loss, **self.solver_kwargs)
- params, state = solver.run(init_params, X=X, y=spike_data[:, window_size:])
-
- if jnp.isnan(params[0]).any() or jnp.isnan(params[1]).any():
- raise ValueError(
- "Solver returned at least one NaN parameter, so solution is invalid!"
- " Try tuning optimization hyperparameters."
+ self._regularizer = regularizer
+
+ @property
+ def observation_model(self):
+ """Getter for the observation_model attribute."""
+ return self._observation_model
+
+ @observation_model.setter
+ def observation_model(self, observation: obs.Observations):
+ # check that the model has the required attributes
+ # and that the attribute can be called
+ obs.check_observation_model(observation)
+ self._observation_model = observation
+
+ def _check_is_fit(self):
+ """Ensure the instance has been fitted."""
+ if (self.coef_ is None) or (self.intercept_ is None):
+ raise NotFittedError(
+ "This GLM instance is not fitted yet. Call 'fit' with appropriate arguments."
)
- # Store parameters
- self.spike_basis_coeff_ = params[0]
- self.baseline_log_fr_ = params[1]
- # note that this will include an error value, which is not the same as
- # the output of loss. I believe it's the output of
- # solver.l2_optimality_error
- self.solver_state = state
- self.solver = solver
def _predict(
- self, params: Tuple[jnp.ndarray, jnp.ndarray], convolved_spike_data: NDArray
+ self, params: Tuple[jnp.ndarray, jnp.ndarray], X: jnp.ndarray
) -> jnp.ndarray:
- """Helper function for generating predictions.
+ """
+ Predicts firing rates based on given parameters and design matrix.
- This way, can use same functions during and after fitting.
+ This function computes the predicted firing rates using the provided parameters
+ and model design matrix `X`. It is a streamlined version used internally within
+ optimization routines, where it serves as the loss function. Unlike the `GLM.predict`
+ method, it does not perform any input validation, assuming that the inputs are pre-validated.
- Note that the ``n_timebins`` here is not necessarily the same as in
- public functions: in particular, this method expects the *convolved*
- spike data, which (since we use the "valid" convolutional output) means
- that it will have fewer timebins than the un-convolved data.
Parameters
----------
- params : ((n_neurons, n_basis_funcs, n_neurons), (n_neurons,))
- Values for the spike basis coefficients and bias terms.
- convolved_spike_data : (n_basis_funcs, n_neurons, n_timebins)
- Spike counts convolved with some set of bases functions.
+ params :
+ Tuple containing the spike basis coefficients and bias terms.
+ X :
+ Predictors. Shape (n_time_bins, n_neurons, n_features).
Returns
-------
- predicted_firing_rates : (n_neurons, n_timebins)
- The predicted firing rates.
-
+ :
+ The predicted rates. Shape (n_time_bins, n_neurons).
"""
Ws, bs = params
- return self.inverse_link_function(
- jnp.einsum("nbt,nbj->nt", convolved_spike_data, Ws) + bs[:, None]
+ return self._observation_model.inverse_link_function(
+ jnp.einsum("ik,tik->ti", Ws, X) + bs[None, :]
)
- def _score(
- self, predicted_firing_rates: NDArray, target_spikes: NDArray
- ) -> jnp.ndarray:
- """Score the predicted firing rates against target spike counts.
+ def predict(self, X: Union[NDArray, jnp.ndarray]) -> jnp.ndarray:
+ """Predict rates based on fit parameters.
- This computes the Poisson negative log-likehood.
+ Parameters
+ ----------
+ X :
+ The exogenous variables. Shape (n_time_bins, n_neurons, n_features).
- Note that you can end up with infinities in here if there are zeros in
- ``predicted_firing_rates``. We raise a warning in that case.
+ Returns
+ -------
+ :
+ The predicted rates with shape (n_time_bins, n_neurons).
- Parameters
- ----------
- predicted_firing_rates : (n_neurons, n_timebins)
- The predicted firing rates.
- target_spikes : (n_neurons, n_timebins)
- The target spikes to compare against
+ Raises
+ ------
+ NotFittedError
+ If ``fit`` has not been called first with this instance.
+ ValueError
+ - If `params` is not a JAX pytree of size two.
+ - If weights and bias terms in `params` don't have the expected dimensions.
+ - If the number of neurons in the model parameters and in the inputs do not match.
+ - If `X` is not three-dimensional.
+ - If there's an inconsistent number of features between spike basis coefficients and `X`.
- Returns
- -------
- score : (1,)
- The Poisson negative log-likehood
+ See Also
+ --------
+ - [score](./#nemos.glm.GLM.score)
+ Score predicted rates against target spike counts.
+ - [simulate (feed-forward only)](../glm/#nemos.glm.GLM.simulate)
+ Simulate neural activity in response to a feed-forward input.
+ - [simulate_recurrent (feed-forward + coupling)](../glm/#nemos.glm.GLMRecurrent.simulate_recurrent)
+ Simulate neural activity in response to a feed-forward input
+ using the GLM as a recurrent network.
+ """
+ # check that the model is fitted
+ self._check_is_fit()
+ # extract model params
+ Ws = self.coef_
+ bs = self.intercept_
+
+ X = jnp.asarray(X, dtype=float)
+
+ # check input dimensionality
+ self._check_input_dimensionality(X=X)
+ # check consistency between X and params
+ self._check_input_and_params_consistency((Ws, bs), X=X)
+ return self._predict((Ws, bs), X)
- Notes
- -----
- The Poisson probably mass function is:
+ def _predict_and_compute_loss(
+ self,
+ params: Tuple[jnp.ndarray, jnp.ndarray],
+ X: jnp.ndarray,
+ y: jnp.ndarray,
+ ) -> jnp.ndarray:
+ r"""Predict the rate and compute the negative log-likelihood against neural activity.
- .. math::
- \frac{\lambda^k \exp(-\lambda)}{k!}
+ This method computes the negative log-likelihood up to a constant term. Unlike `score`,
+ it does not conduct parameter checks prior to evaluation. Passed directly to the solver,
+ it serves to establish the optimization objective for learning the model parameters.
- Thus, the negative log of it is:
+ Parameters
+ ----------
+ params :
+ Values for the spike basis coefficients and bias terms. Shape ((n_neurons, n_features), (n_neurons,)).
+ X :
+ The exogenous variables. Shape (n_time_bins, n_neurons, n_features).
+ y :
+ The target activity to compare against. Shape (n_time_bins, n_neurons).
- .. math::
- ¨ -\log{\frac{\lambda^k\exp{-\lambda}}{k!}} &= -[\log(\lambda^k)+\log(\exp{-\lambda})-\log(k!)]
- &= -k\log(\lambda)-\lambda+\log(\Gamma(k+1))
+ Returns
+ -------
+ :
+ The model negative log-likelihood. Shape (1,).
- Because $\Gamma(k+1)=k!$, see
- https://en.wikipedia.org/wiki/Gamma_function.
+ """
+ predicted_rate = self._predict(params, X)
+ return self._observation_model.negative_log_likelihood(predicted_rate, y)
- And, in our case, ``target_spikes`` is $k$ and
- ``predicted_firing_rates`` is $\lambda$
+ def score(
+ self,
+ X: Union[NDArray, jnp.ndarray],
+ y: Union[NDArray, jnp.ndarray],
+ score_type: Literal[
+ "log-likelihood", "pseudo-r2-McFadden", "pseudo-r2-Cohen"
+ ] = "pseudo-r2-McFadden",
+ ) -> jnp.ndarray:
+ r"""Evaluate the goodness-of-fit of the model to the observed neural data.
- """
- x = target_spikes * jnp.log(predicted_firing_rates)
- # this is a jax jit-friendly version of saying "put a 0 wherever
- # there's a NaN". we do this because NaNs result from 0*log(0)
- # (log(0)=-inf and any non-zero multiplied by -inf gives the expected
- # +/- inf)
- x = jnp.where(jnp.isnan(x), jnp.zeros_like(x), x)
- # see above for derivation of this.
- return jnp.mean(
- predicted_firing_rates - x + jax.scipy.special.gammaln(target_spikes + 1)
- )
+ This method computes the goodness-of-fit score, which can be either the mean
+ log-likelihood or one of two versions of the pseudo-$R^2$.
+ The scoring process includes validation of input compatibility with the model's
+ parameters, ensuring that the model has been previously fitted and the input data
+ are appropriate for scoring. A higher score indicates a better fit of the model
+ to the observed data.
- def predict(self, spike_data: NDArray) -> jnp.ndarray:
- """Predict firing rates based on fit parameters, for checking against existing data.
Parameters
----------
- spike_data : (n_neurons, n_timebins)
- Spike counts arranged in a matrix. n_neurons must be the same as
- during the fitting of this GLM instance.
+ X :
+ The exogenous variables. Shape (n_time_bins, n_neurons, n_features)
+ y :
+ Neural activity arranged in a matrix. n_neurons must be the same as
+ during the fitting of this GLM instance. Shape (n_time_bins, n_neurons).
+ score_type :
+ Type of scoring: either log-likelihood or pseudo-r2.
Returns
-------
- predicted_firing_rates : (n_neurons, n_timebins - window_size + 1)
- The predicted firing rates.
+ score :
+ The log-likelihood or the pseudo-$R^2$ of the current model.
Raises
------
@@ -276,156 +249,366 @@ def predict(self, spike_data: NDArray) -> jnp.ndarray:
If ``fit`` has not been called first with this instance.
ValueError
If attempting to simulate a different number of neurons than were
- present during fitting (i.e., if ``init_spikes.shape[0] !=
- self.baseline_log_fr_.shape[0]``).
-
- See Also
- --------
- score
- Score predicted firing rates against target spike counts.
- simulate
- Simulate spikes using GLM as a recurrent network, for extrapolating into the future.
+ present during fitting (i.e., if ``init_y.shape[0] !=
+ self.intercept_.shape[0]``).
+
+ Notes
+ -----
+ The log-likelihood is not on a standard scale; its value is influenced by many factors,
+ among which the number of model parameters. The log-likelihood can assume both positive
+ and negative values.
+
+ The Pseudo-$ R^2 $ is not equivalent to the $ R^2 $ value in linear regression. While both
+ provide a measure of model fit and assume values in the [0, 1] range, the methods and
+ interpretations can differ. The Pseudo-$ R^2 $ is particularly useful for generalized linear
+ models when the interpretation of the $ R^2 $ as explained variance does not apply
+ (i.e., when the observations are not Gaussian distributed).
+
+ Why is the traditional $R^2$ usually a poor measure of performance in GLMs?
+
+ 1. In the context of GLMs the variance and the mean of the observations are related.
+ Ignoring the relation between them can result in underestimating the model
+ performance; for instance, when we model a Poisson variable with large mean we expect an
+ equally large variance. In this scenario, even if our model perfectly captures the mean,
+ the high variance will result in large residuals and a low $R^2$.
+ Additionally, when the mean of the observations varies, the variance will vary too. This
+ violates the "homoschedasticity" assumption, necessary for interpreting the $R^2$ as
+ variance explained.
+ 2. The $R^2$ captures the variance explained when the relationship between the observations and
+ the predictors is linear. In GLMs, the link function sets a non-linear mapping between the predictors
+ and the mean of the observations, compromising the interpretation of the $R^2$.
+
+ Note that it is possible to re-normalize the residuals by a mean-dependent quantity proportional
+ to the model standard deviation (i.e. Pearson residuals). This "rescaled" residual distribution, however,
+ deviates substantially from normality for count data with low mean (common for spike counts).
+ Therefore, even the Pearson residuals perform poorly as a measure of fit quality, especially
+ for GLMs modeling count data.
+
+ Refer to the `nmo.observation_models.Observations` concrete subclasses for the likelihood and
+ pseudo-$R^2$ equations.
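+
+ Examples
+ --------
+ >>> # Sketch: assumes `model`, `X`, and `y` defined as in the class-level example.
+ >>> ll = model.score(X, y, score_type="log-likelihood")
+ >>> r2 = model.score(X, y, score_type="pseudo-r2-McFadden")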
"""
- try:
- Ws = self.spike_basis_coeff_
- except AttributeError:
- raise NotFittedError(
- "This GLM instance is not fitted yet. Call 'fit' with appropriate arguments."
+ self._check_is_fit()
+ Ws = self.coef_
+ bs = self.intercept_
+
+ X, y = jnp.asarray(X, dtype=float), jnp.asarray(y, dtype=float)
+
+ self._check_input_dimensionality(X, y)
+ self._check_input_n_timepoints(X, y)
+ self._check_input_and_params_consistency((Ws, bs), X=X, y=y)
+
+ if score_type == "log-likelihood":
+ norm_constant = jax.scipy.special.gammaln(y + 1).mean()
+ score = -self._predict_and_compute_loss((Ws, bs), X, y) - norm_constant
+ elif score_type.startswith("pseudo-r2"):
+ score = self._observation_model.pseudo_r2(
+ self._predict((Ws, bs), X), y, score_type=score_type
)
- bs = self.baseline_log_fr_
- if spike_data.shape[0] != bs.shape[0]:
- raise ValueError(
- "Number of neurons must be the same during prediction and fitting! "
- f"spike_data n_neurons: {spike_data.shape[0]}, "
- f"self.baseline_log_fr_ n_neurons: {self.baseline_log_fr_.shape[0]}"
+ else:
+ raise NotImplementedError(
+ f"Scoring method {score_type} not implemented! "
+ "`score_type` must be either 'log-likelihood', 'pseudo-r2-McFadden', "
+ "or 'pseudo-r2-Cohen'."
)
- X = jnp.transpose(
- convolve_1d_trials(self.spike_basis_matrix.T, spike_data.T[None, :, :])[0],
- (1, 2, 0),
- )
- return self._predict((Ws, bs), X)
+ return score
+
+ def fit(
+ self,
+ X: Union[NDArray, jnp.ndarray],
+ y: Union[NDArray, jnp.ndarray],
+ init_params: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None,
+ ):
+ """Fit GLM to neural activity.
+
+ Fit and store the model parameters as attributes
+ ``coef_`` and ``intercept_``.
+
+ Parameters
+ ----------
+ X :
+ Predictors, shape (n_time_bins, n_neurons, n_features)
+ y :
+ Neural activity arranged in a matrix, shape (n_time_bins, n_neurons).
+ init_params :
+ Initial values for the activity basis coefficients and bias terms. If
+ None, coefficients are initialized to zero and intercepts to the log of the mean activity.
+ Shape: ((n_neurons, n_features), (n_neurons,)).
+
+ Raises
+ ------
+ ValueError
+ - If `init_params` is not of length two.
+ - If dimensionality of `init_params` are not correct.
+ - If the number of neurons in the model parameters and in the inputs do not match.
+ - If `X` is not three-dimensional.
+ - If `y` is not two-dimensional.
+ - If solver returns at least one NaN parameter, which means it found
+ an invalid solution. Try tuning optimization hyperparameters.
+ TypeError
+ - If `init_params` are not array-like
+ - If `init_params[i]` cannot be converted to jnp.ndarray for all i
+ """
+ # convert to jnp.ndarray & perform checks
+ X, y, init_params = self._preprocess_fit(X, y, init_params)
+
+ # Run optimization
+ runner = self.regularizer.instantiate_solver(self._predict_and_compute_loss)
+ params, state = runner(init_params, X, y)
- def score(self, spike_data: NDArray) -> jnp.ndarray:
- """Score the predicted firing rates (based on fit) to the target spike counts.
+ # estimate the GLM scale
+ self.observation_model.estimate_scale(self._predict(params, X))
- This ignores the last time point of the prediction.
+ if jnp.isnan(params[0]).any() or jnp.isnan(params[1]).any():
+ raise ValueError(
+ "Solver returned at least one NaN parameter, so solution is invalid!"
+ " Try tuning optimization hyperparameters."
+ )
+
+ # Store parameters
+ self.coef_: jnp.ndarray = params[0]
+ self.intercept_: jnp.ndarray = params[1]
+ # note that this will include an error value, which is not the same as
+ # the output of loss. I believe it's the output of
+ # solver.l2_optimality_error
+ self.solver_state = state
- This computes the Poisson negative log-likehood, thus the lower the
- number the better, and zero isn't special (you can have a negative
- score if ``spike_data > 0`` and and ``log(predicted_firing_rates) < 0``
+ def simulate(
+ self,
+ random_key: jax.random.PRNGKeyArray,
+ feedforward_input: Union[NDArray, jnp.ndarray],
+ ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+ """Simulate neural activity in response to a feed-forward input.
Parameters
----------
- spike_data : (n_neurons, n_timebins)
- Spike counts arranged in a matrix. n_neurons must be the same as
- during the fitting of this GLM instance.
+ random_key :
+ PRNGKey for seeding the simulation.
+ feedforward_input :
+ External input matrix to the model, representing factors like convolved currents,
+ light intensities, etc.
+ Expected shape: (n_time_bins, n_neurons, n_basis_input).
Returns
-------
- score : (1,)
- The Poisson negative log-likehood
+ simulated_activity :
+ Simulated activity (spike counts for Poisson GLMs) for each neuron over time.
+ Shape: (n_time_bins, n_neurons).
+ firing_rates :
+ Simulated rates for each neuron over time. Shape: (n_time_bins, n_neurons).
Raises
------
NotFittedError
- If ``fit`` has not been called first with this instance.
+ If the model hasn't been fitted prior to calling this method.
ValueError
- If attempting to simulate a different number of neurons than were
- present during fitting (i.e., if ``init_spikes.shape[0] !=
- self.baseline_log_fr_.shape[0]``).
- UserWarning
- If there are any zeros in ``self.predict(spike_data)``, since this
- will likely lead to infinite log-likelihood values being returned.
+ - If the instance has not been previously fitted.
+ - If there's an inconsistency between the number of neurons in model parameters.
+ - If the number of neurons in input arguments doesn't match with model parameters.
+
+ See Also
+ --------
+ [predict](./#nemos.glm.GLM.predict) :
+ Method to predict rates based on the model's parameters.
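+
+ Examples
+ --------
+ >>> # Sketch: assumes a fitted `model` and a design matrix `X` as in the class-level example.
+ >>> import jax
+ >>> key = jax.random.PRNGKey(123)
+ >>> spikes, rates = model.simulate(key, X)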
"""
- # ignore the last time point from predict, because that corresponds to
- # the next time step, which we have no observed data for
- predicted_firing_rates = self.predict(spike_data)[:, :-1]
- if (predicted_firing_rates == 0).any():
- warnings.warn(
- "predicted_firing_rates array contained zeros, this can "
- "lead to infinite log-likelihood values."
- )
- window_size = self.spike_basis_matrix.shape[1]
- return self._score(predicted_firing_rates, spike_data[:, window_size:])
+ # check if the model is fit
+ self._check_is_fit()
+ Ws, bs = self.coef_, self.intercept_
+ (feedforward_input,) = self._preprocess_simulate(
+ feedforward_input, params_feedforward=(Ws, bs)
+ )
+ predicted_rate = self._predict((Ws, bs), feedforward_input)
+ return (
+ self._observation_model.sample_generator(
+ key=random_key, predicted_rate=predicted_rate
+ ),
+ predicted_rate,
+ )
- def simulate(
+
+class GLMRecurrent(GLM):
+ """
+ A Generalized Linear Model (GLM) with recurrent dynamics.
+
+ This class extends the basic GLM to capture recurrent dynamics between neurons and
+ self-connectivity, making it more suitable for simulating the activity of interconnected
+ neural populations. The recurrent GLM combines both feedforward inputs (like sensory
+ stimuli) and past neural activity to simulate or predict future neural activity.
+
+ Parameters
+ ----------
+ observation_model :
+ The observation model to use for the GLM. This defines how neural activity is generated
+ based on the underlying firing rate. Common choices include Poisson and Gaussian models.
+ regularizer :
+ The regularization scheme to use for fitting the GLM parameters.
+
+ See Also
+ --------
+ [GLM](./#nemos.glm.GLM) : Base class for the generalized linear model.
+
+ Notes
+ -----
+ - The recurrent GLM assumes that neural activity can be influenced by both feedforward
+ inputs and the past activity of the same and other neurons. This makes it particularly
+ powerful for capturing the dynamics of neural networks where neurons are interconnected.
+
+ - The attributes of `GLMRecurrent` are inherited from the parent `GLM` class, and include
+ coefficients, fitted status, and other model-related attributes.
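+
+ Examples
+ --------
+ >>> # Construction sketch using the default Poisson observations and Ridge regularizer.
+ >>> from nemos.glm import GLMRecurrent
+ >>> model = GLMRecurrent()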
+ """
+
+ def __init__(
+ self,
+ observation_model: obs.Observations = obs.PoissonObservations(),
+ regularizer: reg.Regularizer = reg.Ridge(),
+ ):
+ super().__init__(observation_model=observation_model, regularizer=regularizer)
+
+ def simulate_recurrent(
self,
random_key: jax.random.PRNGKeyArray,
- n_timesteps: int,
- init_spikes: NDArray,
- ) -> jnp.ndarray:
- """Simulate spikes using GLM as a recurrent network, for extrapolating into the future.
+ feedforward_input: Union[NDArray, jnp.ndarray],
+ coupling_basis_matrix: Union[NDArray, jnp.ndarray],
+ init_y: Union[NDArray, jnp.ndarray],
+ ):
+ """
+ Simulate neural activity using the GLM as a recurrent network.
+
+ This function projects neural activity into the future, employing the fitted
+ parameters of the GLM. It is capable of simulating activity based on a combination
+ of historical activity and external feedforward inputs like convolved currents, light
+ intensities, etc.
Parameters
----------
- random_key
- jax PRNGKey to seed simulation with.
- n_timesteps
- Number of time steps to simulate.
- init_spikes : (n_neurons, window_size)
- Spike counts arranged in a matrix. These are used to jump start the
- forward simulation. ``n_neurons`` must be the same as during the
- fitting of this GLM instance and ``window_size`` must be the same
- as the bases functions (i.e., ``self.spike_basis_matrix.shape[1]``)
+ random_key :
+ PRNGKey for seeding the simulation.
+ feedforward_input :
+ External input matrix to the model, representing factors like convolved currents,
+ light intensities, etc.
+ Expected shape: (n_time_bins, n_neurons, n_basis_input).
+ init_y :
+ Initial observation (spike counts for PoissonGLM) matrix that kickstarts the simulation.
+ Expected shape: (window_size, n_neurons).
+ coupling_basis_matrix :
+ Basis matrix for coupling, representing between-neuron couplings
+ and auto-correlations. Expected shape: (window_size, n_basis_coupling).
Returns
-------
- simulated_spikes : (n_neurons, n_timesteps)
- The simulated spikes.
+ simulated_activity :
+ Simulated activity (spike counts for Poisson GLMs) for each neuron over time.
+ Shape: (n_time_bins, n_neurons).
+ firing_rates :
+ Simulated rates for each neuron over time. Shape: (n_time_bins, n_neurons).
Raises
------
NotFittedError
- If ``fit`` has not been called first with this instance.
+ If the model hasn't been fitted prior to calling this method.
ValueError
- If attempting to simulate a different number of neurons than were
- present during fitting (i.e., if ``init_spikes.shape[0] !=
- self.baseline_log_fr_.shape[0]``) or if ``init_spikes`` has the
- wrong number of time steps (i.e., if ``init_spikes.shape[1] !=
- self.spike_basis_matrix.shape[1]``)
+ - If the instance has not been previously fitted.
+ - If there's an inconsistency between the number of neurons in model parameters.
+ - If the number of neurons in input arguments doesn't match with model parameters.
+
See Also
--------
- predict
- Predict firing rates based on fit parameters, for checking against existing data.
+ [predict](./#nemos.glm.GLM.predict) :
+ Method to predict rates based on the model's parameters.
+
+ Notes
+ -----
+ The model coefficients (`self.coef_`) are structured such that the first set of coefficients
+ (of size `n_basis_coupling * n_neurons`) are interpreted as the weights for the recurrent couplings.
+ The remaining coefficients correspond to the weights for the feed-forward input.
+
+ The sum of `n_basis_input` and `n_basis_coupling * n_neurons` should equal `self.coef_.shape[1]`
+ to ensure consistency in the model's input feature dimensionality.
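+
+ Examples
+ --------
+ >>> # Shape sketch only (not runnable as-is): `model` is a fitted GLMRecurrent and
+ >>> # `ff_input` its feed-forward design; the sizes below are assumptions.
+ >>> import jax
+ >>> import jax.numpy as jnp
+ >>> key = jax.random.PRNGKey(0)
+ >>> coupling_basis = jnp.ones((3, 2)) # (window_size, n_basis_coupling)
+ >>> init_y = jnp.zeros((3, 2)) # (window_size, n_neurons)
+ >>> spikes, rates = model.simulate_recurrent(key, ff_input, coupling_basis, init_y)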
"""
- try:
- Ws = self.spike_basis_coeff_
- except AttributeError:
- raise NotFittedError(
- "This GLM instance is not fitted yet. Call 'fit' with appropriate arguments."
- )
- bs = self.baseline_log_fr_
+ # check if the model is fit
+ self._check_is_fit()
- if init_spikes.shape[0] != bs.shape[0]:
- raise ValueError(
- "Number of neurons must be the same during simulation and fitting! "
- f"init_spikes n_neurons: {init_spikes.shape[0]}, "
- f"self.baseline_log_fr_ n_neurons: {self.baseline_log_fr_.shape[0]}"
- )
- if init_spikes.shape[1] != self.spike_basis_matrix.shape[1]:
+ # convert to jnp.ndarray
+ coupling_basis_matrix = jnp.asarray(coupling_basis_matrix, dtype=float)
+
+ n_basis_coupling = coupling_basis_matrix.shape[1]
+ n_neurons = self.intercept_.shape[0]
+
+ w_feedforward = self.coef_[:, n_basis_coupling * n_neurons :]
+ w_recurrent = self.coef_[:, : n_basis_coupling * n_neurons]
+ bs = self.intercept_
+
+ feedforward_input, init_y = self._preprocess_simulate(
+ feedforward_input,
+ params_feedforward=(w_feedforward, bs),
+ init_y=init_y,
+ params_recurrent=(w_recurrent, bs),
+ )
+
+ self._check_input_and_params_consistency(
+ (w_recurrent, bs),
+ y=init_y,
+ )
+
+ if init_y.shape[0] != coupling_basis_matrix.shape[0]:
raise ValueError(
- "init_spikes has the wrong number of time steps!"
- f"init_spikes time steps: {init_spikes.shape[1]}, "
- f"spike_basis_matrix window size: {self.spike_basis_matrix.shape[1]}"
+ "`init_y` and `coupling_basis_matrix`"
+ " should have the same window size! "
+ f"`init_y` window size: {init_y.shape[1]}, "
+ f"`coupling_basis_matrix` window size: {coupling_basis_matrix.shape[1]}"
)
- subkeys = jax.random.split(random_key, num=n_timesteps)
+ subkeys = jax.random.split(random_key, num=feedforward_input.shape[0])
+ # (n_samples, n_neurons)
+ feed_forward_contrib = jnp.einsum(
+ "ik,tik->ti", w_feedforward, feedforward_input
+ )
- def scan_fn(spikes, key):
- # (n_neurons, n_basis_funcs, 1)
- X = jnp.transpose(
- convolve_1d_trials(self.spike_basis_matrix.T, spikes.T[None, :, :])[0],
- (1, 2, 0),
+ def scan_fn(
+ data: Tuple[jnp.ndarray, int], key: jax.random.PRNGKeyArray
+ ) -> Tuple[Tuple[jnp.ndarray, int], Tuple[jnp.ndarray, jnp.ndarray]]:
+ """Scan over time steps and simulate activity and rates.
+
+ This function simulates the neural activity and firing rates for each time step
+ based on the previous activity, feedforward input, and model coefficients.
+ """
+ activity, t_sample = data
+
+ # Convolve the neural activity with the coupling basis matrix
+ # Output of shape (1, n_neuron, n_basis_coupling)
+ # 1. The first dimension is time, and 1 is by construction since we are simulating 1
+ # sample
+ # 2. Flatten to shape (n_neuron * n_basis_coupling, )
+ conv_act = utils.convolve_1d_trials(coupling_basis_matrix, activity[None])[
+ 0
+ ].flatten()
+
+ # Extract the slice of the feedforward input for the current time step
+ input_slice = jax.lax.dynamic_slice(
+ feed_forward_contrib,
+ (t_sample, 0),
+ (1, feed_forward_contrib.shape[1]),
+ ).squeeze(axis=0)
+
+ # Predict the firing rate using the model coefficients
+ # Doesn't use predict because the non-linearity needs
+ # to be applied after we add the feed forward input
+ firing_rate = self._observation_model.inverse_link_function(
+ w_recurrent.dot(conv_act) + input_slice + bs
)
- fr = self._predict((Ws, bs), X).squeeze(-1)
- new_spikes = jax.random.poisson(key, fr)
- concat_spikes = jnp.column_stack((spikes[:, 1:], new_spikes))
- return concat_spikes, new_spikes
- _, simulated_spikes = jax.lax.scan(scan_fn, init_spikes, subkeys)
+ # Simulate activity based on the predicted firing rate
+ new_act = self._observation_model.sample_generator(key, firing_rate)
+
+ # Shift of one sample the spike count window
+ # for the next iteration (i.e. remove the first counts, and
+ # stack the newly generated sample)
+ # Increase the t_sample by one
+ carry = jnp.row_stack((activity[1:], new_act)), t_sample + 1
+ return carry, (new_act, firing_rate)
- return simulated_spikes.T
+ _, outputs = jax.lax.scan(scan_fn, (init_y, 0), subkeys)
+ simulated_activity, firing_rates = outputs
+ return simulated_activity, firing_rates
diff --git a/src/nemos/observation_models.py b/src/nemos/observation_models.py
new file mode 100644
index 00000000..ce80bad7
--- /dev/null
+++ b/src/nemos/observation_models.py
@@ -0,0 +1,568 @@
+"""Observation model classes for GLMs."""
+
+import abc
+from typing import Callable, Literal, Union
+
+import jax
+import jax.numpy as jnp
+
+from . import utils
+from .base_class import Base
+
+KeyArray = Union[jnp.ndarray, jax.random.PRNGKeyArray]
+
+__all__ = ["PoissonObservations"]
+
+
+def __dir__():
+ return __all__
+
+
+class Observations(Base, abc.ABC):
+ """
+ Abstract observation model class for neural data processing.
+
+ This is an abstract base class used to implement observation models for neural data.
+ Specific observation models that inherit from this class should define their versions
+ of the abstract methods: `negative_log_likelihood`, `sample_generator`, `deviance`,
+ and `estimate_scale`.
+
+ Attributes
+ ----------
+ inverse_link_function :
+ A function that transforms a set of predictors to the domain of the model parameter.
+
+ See Also
+ --------
+ [PoissonObservations](./#nemos.observation_models.PoissonObservations) : A specific implementation of an
+ observation model using the Poisson distribution.
+ """
+
+ def __init__(self, inverse_link_function: Callable, **kwargs):
+ super().__init__(**kwargs)
+ self.inverse_link_function = inverse_link_function
+ self.scale = 1.0
+
+ @property
+ def inverse_link_function(self):
+ """Getter for the inverse link function for the model."""
+ return self._inverse_link_function
+
+ @inverse_link_function.setter
+ def inverse_link_function(self, inverse_link_function: Callable):
+ """Setter for the inverse link function for the model."""
+ self.check_inverse_link_function(inverse_link_function)
+ self._inverse_link_function = inverse_link_function
+
+ @property
+ def scale(self):
+ """Getter for the scale parameter of the model."""
+ return self._scale
+
+ @scale.setter
+ def scale(self, value: Union[int, float]):
+ """Setter for the scale parameter of the model."""
+ if not isinstance(value, (int, float)):
+ raise ValueError("The `scale` parameter must be of numeric type.")
+ self._scale = value
+
+ @staticmethod
+ def check_inverse_link_function(inverse_link_function: Callable):
+ """
+ Check if the provided inverse_link_function is usable.
+
+ This function verifies if the inverse link function:
+ 1. Is callable
+ 2. Returns a jax.numpy.ndarray
+ 3. Is differentiable (via jax)
+
+ Parameters
+ ----------
+ inverse_link_function :
+ The function to be checked.
+
+ Raises
+ ------
+ TypeError
+ If the function is not callable, does not return a jax.numpy.ndarray,
+ or is not differentiable.
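+
+ Examples
+ --------
+ >>> # `jnp.exp` is callable, returns jax arrays, and is differentiable, so this passes silently.
+ >>> import jax.numpy as jnp
+ >>> Observations.check_inverse_link_function(jnp.exp)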
+ """
+ # check that it's callable
+ if not callable(inverse_link_function):
+ raise TypeError("The `inverse_link_function` function must be a Callable!")
+
+ # check if the function returns a jax array for a 1D array
+ array_out = inverse_link_function(jnp.array([1.0, 2.0, 3.0]))
+ if not isinstance(array_out, jnp.ndarray):
+ raise TypeError(
+ "The `inverse_link_function` must return a jax.numpy.ndarray!"
+ )
+
+ # Optionally: Check for scalar input
+ scalar_out = inverse_link_function(1.0)
+ if not isinstance(scalar_out, (jnp.ndarray, float, int)):
+ raise TypeError(
+ "The `inverse_link_function` must handle scalar inputs correctly and return a scalar or a "
+ "jax.numpy.ndarray!"
+ )
+
+ # check for autodiff
+ try:
+ gradient_fn = jax.grad(inverse_link_function)
+ gradient_fn(1.0)
+ except Exception as e:
+ raise TypeError(
+ f"The `inverse_link_function` function cannot be differentiated. Error: {e}"
+ )
+
+ @abc.abstractmethod
+ def negative_log_likelihood(self, predicted_rate, y):
+ r"""Compute the observation model negative log-likelihood.
+
+ This computes the negative log-likelihood of the predicted rates
+ for the observed neural activity up to a constant.
+
+ Parameters
+ ----------
+ predicted_rate :
+ The predicted rate of the current model. Shape (n_time_bins, n_neurons).
+ y :
+ The target activity to compare against. Shape (n_time_bins, n_neurons).
+
+ Returns
+ -------
+ :
+ The negative log-likelihood. Shape (1,).
+ """
+ pass
+
+ @abc.abstractmethod
+ def sample_generator(
+ self, key: KeyArray, predicted_rate: jnp.ndarray
+ ) -> jnp.ndarray:
+ """
+ Sample from the estimated distribution.
+
+ This method generates random numbers from the desired distribution based on the given
+ `predicted_rate`.
+
+ Parameters
+ ----------
+ key :
+ Random key used for the generation of random numbers in JAX.
+ predicted_rate :
+ Expected rate of the distribution. Shape (n_time_bins, n_neurons).
+
+ Returns
+ -------
+ :
+ Random numbers generated from the observation model with `predicted_rate`.
+ """
+ pass
+
+ @abc.abstractmethod
+ def deviance(self, predicted_rate: jnp.ndarray, spike_counts: jnp.ndarray):
+ r"""Compute the residual deviance for the observation model.
+
+ Parameters
+ ----------
+ predicted_rate:
+ The predicted firing rates. Shape (n_time_bins, n_neurons).
+ spike_counts:
+ The spike counts. Shape (n_time_bins, n_neurons).
+
+ Returns
+ -------
+ :
+ The residual deviance of the model.
+ """
+ pass
+
+ @abc.abstractmethod
+ def estimate_scale(self, predicted_rate: jnp.ndarray) -> None:
+ r"""Estimate the scale parameter for the model.
+
+ This method estimates the scale parameter, often denoted as $\phi$, which determines the dispersion
+ of an exponential family distribution. The probability density function (pdf) for such a distribution
+ is generally expressed as
+ $f(y; \theta, \phi) \propto \exp \left(a(\phi)\left( y\theta - k(\theta) \right)\right)$.
+
+ The relationship between variance and the scale parameter is given by:
+ $$
+ \text{var}(Y) = \frac{V(\mu)}{a(\phi)}.
+ $$
+
+ The scale parameter, $\phi$, is necessary for capturing the variance of the data accurately.
+
+ Parameters
+ ----------
+ predicted_rate :
+ The predicted rate values.
+ """
+ pass
+
+ def pseudo_r2(
+ self,
+ predicted_rate: jnp.ndarray,
+ y: jnp.ndarray,
+ score_type: Literal[
+ "pseudo-r2-McFadden", "pseudo-r2-Cohen"
+ ] = "pseudo-r2-McFadden",
+ ) -> jnp.ndarray:
+ r"""Pseudo-$R^2$ calculation for a GLM.
+
+ Compute the pseudo-$R^2$ metric for the GLM, as defined by McFadden[$^1$](#references)
+ or by Cohen et al.[$^2$](#references).
+
+ This metric evaluates the goodness-of-fit of the model relative to a null (baseline) model that assumes a
+ constant mean for the observations. While the pseudo-$R^2$ is bounded between 0 and 1 for the training set,
+ it can yield negative values on out-of-sample data, indicating potential over-fitting.
+
+ Parameters
+ ----------
+ predicted_rate:
+ The mean neural activity. Expected shape: (n_time_bins, n_neurons)
+ y:
+ The neural activity. Expected shape: (n_time_bins, n_neurons)
+ score_type:
+ The pseudo-R$^2$ type.
+
+ Returns
+ -------
+ :
+ The pseudo-$R^2$ of the model. A value closer to 1 indicates a better model fit,
+ whereas a value closer to 0 suggests that the model doesn't improve much over the null model.
+
+ Notes
+ -----
+ - The McFadden pseudo-$R^2$ is given by:
+ $$
+ R^2_{\text{mcf}} = 1 - \frac{\log(L_{M})}{\log(L_0)}.
+ $$
+ *Equivalent to statsmodels
+ [`GLMResults.pseudo_rsquared(kind='mcf')`](https://www.statsmodels.org/dev/generated/statsmodels.genmod.generalized_linear_model.GLMResults.pseudo_rsquared.html).*
+ - The Cohen pseudo-$R^2$ is given by:
+ $$
+ \begin{aligned}
+ R^2_{\text{Cohen}} &= \frac{D_0 - D_M}{D_0} \\\
+ &= 1 - \frac{\log(L_s) - \log(L_M)}{\log(L_s)-\log(L_0)},
+ \end{aligned}
+ $$
+ where $L_M$, $L_0$ and $L_s$ are the likelihood of the fitted model, the null model (a
+ model with only the intercept term), and the saturated model (a model with one parameter per
+ sample, i.e. the maximum value that the likelihood could possibly achieve). $D_M$ and $D_0$ are
+ the model and the null deviance, $D_i = -2 \left[ \log(L_s) - \log(L_i) \right]$ for $i=M,0$.
+
+
+ References
+ ----------
+ 1. McFadden D (1979). Quantitative methods for analysing travel behavior of individuals: Some recent
+ developments. In D. A. Hensher & P. R. Stopher (Eds.), *Behavioural travel modelling* (pp. 279-318).
+ London: Croom Helm.
+ 2. Jacob Cohen, Patricia Cohen, Steven G. West, Leona S. Aiken.
+ *Applied Multiple Regression/Correlation Analysis for the Behavioral Sciences*.
+ 3rd edition. Routledge, 2002. p.502. ISBN 978-0-8058-2223-6. (May 2012)
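+
+ Examples
+ --------
+ >>> # Toy sketch: a perfect prediction yields a Cohen pseudo-R^2 of 1.
+ >>> import jax.numpy as jnp
+ >>> from nemos.observation_models import PoissonObservations
+ >>> obs = PoissonObservations()
+ >>> y = jnp.array([[1.0], [2.0], [3.0]])
+ >>> float(obs.pseudo_r2(y, y, score_type="pseudo-r2-Cohen"))
+ 1.0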
+ """
+ if score_type == "pseudo-r2-McFadden":
+ pseudo_r2 = self._pseudo_r2_mcfadden(predicted_rate, y)
+ elif score_type == "pseudo-r2-Cohen":
+ pseudo_r2 = self._pseudo_r2_cohen(predicted_rate, y)
+ else:
+ raise NotImplementedError(f"Score {score_type} not implemented!")
+ return pseudo_r2
+
+ def _pseudo_r2_cohen(
+ self, predicted_rate: jnp.ndarray, y: jnp.ndarray
+ ) -> jnp.ndarray:
+ r"""Cohen's pseudo-$R^2$.
+
+ Compute the pseudo-$R^2$ metric as defined by Cohen et al. (2002). See
+ [`pseudo_r2`](#pseudo_r2) for additional information.
+
+ Parameters
+ ----------
+ predicted_rate:
+ The mean neural activity. Expected shape: (n_time_bins, n_neurons)
+ y:
+ The neural activity. Expected shape: (n_time_bins, n_neurons)
+
+ Returns
+ -------
+ :
+ The pseudo-$R^2$ of the model. A value closer to 1 indicates a better model fit,
+ whereas a value closer to 0 suggests that the model doesn't improve much over the null model.
+ """
+ model_dev_t = self.deviance(predicted_rate, y)
+ model_deviance = jnp.sum(model_dev_t)
+
+ null_mu = jnp.ones(y.shape, dtype=jnp.float32) * y.mean()
+ null_dev_t = self.deviance(null_mu, y)
+ null_deviance = jnp.sum(null_dev_t)
+ return (null_deviance - model_deviance) / null_deviance
+
+ def _pseudo_r2_mcfadden(self, predicted_rate: jnp.ndarray, y: jnp.ndarray):
+ """
+ McFadden's pseudo-$R^2$.
+
+ Compute the pseudo-$R^2$ metric as defined by McFadden (1979). See
+ [`pseudo_r2`](#pseudo_r2) for additional information.
+
+ Parameters
+ ----------
+ predicted_rate:
+ The mean neural activity. Expected shape: (n_time_bins, n_neurons)
+ y:
+ The neural activity. Expected shape: (n_time_bins, n_neurons)
+
+ Returns
+ -------
+ :
+ The pseudo-$R^2$ of the model. A value closer to 1 indicates a better model fit,
+ whereas a value closer to 0 suggests that the model doesn't improve much over the null model.
+ """
+ norm = -jax.scipy.special.gammaln(y + 1).mean()
+ mean_y = jnp.ones(y.shape) * y.mean(axis=0)
+ ll_null = -self.negative_log_likelihood(mean_y, y) + norm
+ ll_model = -self.negative_log_likelihood(predicted_rate, y) + norm
+ return 1 - ll_model / ll_null
+
+
+class PoissonObservations(Observations):
+ """
+ Model observations as Poisson random variables.
+
+ The PoissonObservations is designed to model the observed spike counts based on a Poisson distribution
+ with a given rate. It provides methods for computing the negative log-likelihood, generating samples,
+ and computing the residual deviance for the given spike count data.
+
+ Attributes
+ ----------
+ inverse_link_function :
+ A function that maps the predicted rate to the domain of the Poisson parameter. Defaults to jnp.exp.
+
+ See Also
+ --------
+ [Observations](./#nemos.observation_models.Observations) : Base class for observation models.
+ """
+
+ def __init__(self, inverse_link_function=jnp.exp):
+ super().__init__(inverse_link_function=inverse_link_function)
+ self.scale = 1
+
+ def negative_log_likelihood(
+ self,
+ predicted_rate: jnp.ndarray,
+ y: jnp.ndarray,
+ ) -> jnp.ndarray:
+ r"""Compute the Poisson negative log-likelihood.
+
+ This computes the Poisson negative log-likelihood of the predicted rates
+ for the observed spike counts up to a constant.
+
+ Parameters
+ ----------
+ predicted_rate :
+ The predicted rate of the current model. Shape (n_time_bins, n_neurons).
+ y :
+ The target spikes to compare against. Shape (n_time_bins, n_neurons).
+
+ Returns
+ -------
+ :
+ The Poisson negative log-likelihood. Shape (1,).
+
+ Notes
+ -----
+ The formula for the Poisson mean log-likelihood is the following,
+
+ $$
+ \begin{aligned}
+ \text{LL}(\hat{\lambda} | y) &= \frac{1}{T \cdot N} \sum_{n=1}^{N} \sum_{t=1}^{T}
+ [y\_{tn} \log(\hat{\lambda}\_{tn}) - \hat{\lambda}\_{tn} - \log({y\_{tn}!})] \\\
+ &= \frac{1}{T \cdot N} \sum_{n=1}^{N} \sum_{t=1}^{T} [y\_{tn} \log(\hat{\lambda}\_{tn}) -
+ \hat{\lambda}\_{tn} - \log \Gamma({y\_{tn}+1})] \\\
+ &= \frac{1}{T \cdot N} \sum_{n=1}^{N} \sum_{t=1}^{T} [y\_{tn} \log(\hat{\lambda}\_{tn}) -
+ \hat{\lambda}\_{tn}] + \text{const}
+ \end{aligned}
+ $$
+
+ Because $\Gamma(k+1)=k!$, see [wikipedia](https://en.wikipedia.org/wiki/Gamma_function) for explanation.
+
+ The $\log({y\_{tn}!})$ term is not a function of the parameters and can be disregarded
+ when computing the loss function; this is why we incorporated it into the `const` term.
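+
+ Examples
+ --------
+ >>> # Sketch: with rate = counts = 1 the NLL (up to the dropped constant) is 1.0.
+ >>> import jax.numpy as jnp
+ >>> obs = PoissonObservations()
+ >>> float(obs.negative_log_likelihood(jnp.ones((10, 2)), jnp.ones((10, 2))))
+ 1.0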
+ """
+ predicted_rate = jnp.clip(
+ predicted_rate, a_min=jnp.finfo(predicted_rate.dtype).eps
+ )
+ x = y * jnp.log(predicted_rate)
+ # see above for derivation of this.
+ return jnp.mean(predicted_rate - x)
+
+ def sample_generator(
+ self, key: KeyArray, predicted_rate: jnp.ndarray
+ ) -> jnp.ndarray:
+ """
+ Sample from the Poisson distribution.
+
+ This method generates random numbers from a Poisson distribution based on the given
+ `predicted_rate`.
+
+ Parameters
+ ----------
+ key :
+ Random key used for the generation of random numbers in JAX.
+ predicted_rate :
+ Expected rate (lambda) of the Poisson distribution. Shape (n_time_bins, n_neurons).
+
+ Returns
+ -------
+ jnp.ndarray
+ Random numbers generated from the Poisson distribution based on the `predicted_rate`.
+ """
+ return jax.random.poisson(key, predicted_rate)
+
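The method is a thin wrapper around `jax.random.poisson`; a minimal sketch of the sampling call:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
rate = jnp.full((5, 2), 2.0)            # (n_time_bins, n_neurons)
counts = jax.random.poisson(key, rate)  # integer counts with the same shape as `rate`
```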
+ def deviance(
+ self, predicted_rate: jnp.ndarray, spike_counts: jnp.ndarray
+ ) -> jnp.ndarray:
+ r"""Compute the residual deviance for a Poisson model.
+
+ Parameters
+ ----------
+ predicted_rate:
+ The predicted firing rates. Shape (n_time_bins, n_neurons).
+ spike_counts:
+ The spike counts. Shape (n_time_bins, n_neurons).
+
+ Returns
+ -------
+ :
+ The residual deviance of the model.
+
+ Notes
+ -----
+ The deviance is a measure of the goodness of fit of a statistical model.
+ For a Poisson model, the residual deviance is computed as:
+
+ $$
+ \begin{aligned}
+ D(y\_{tn}, \hat{y}\_{tn}) &= 2 \left[ y\_{tn} \log\left(\frac{y\_{tn}}{\hat{y}\_{tn}}\right)
+ - (y\_{tn} - \hat{y}\_{tn}) \right]\\\
+ &= 2 \left( \text{LL}\left(y\_{tn} | y\_{tn}\right) - \text{LL}\left(y\_{tn} | \hat{y}\_{tn}\right)\right)
+ \end{aligned}
+ $$
+
+ where $ y $ is the observed data, $ \hat{y} $ is the predicted data, and $\text{LL}$ is the model
+ log-likelihood. Lower values of deviance indicate a better fit.
+ """
+ # this takes care of 0s in the log
+ ratio = jnp.clip(
+ spike_counts / predicted_rate, jnp.finfo(predicted_rate.dtype).eps, jnp.inf
+ )
+ deviance = 2 * (spike_counts * jnp.log(ratio) - (spike_counts - predicted_rate))
+ return deviance
+
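A small check of the identity stated in the Notes, i.e. that the deviance equals twice the log-likelihood gap between the saturated and the fitted model (a sketch; the `log(y!)` terms cancel in the difference, so they are omitted):

```python
import jax.numpy as jnp

y = jnp.array([2.0, 0.0, 5.0])    # observed counts
mu = jnp.array([1.5, 0.2, 4.0])   # predicted rates
eps = jnp.finfo(mu.dtype).eps

dev = 2 * (y * jnp.log(jnp.clip(y / mu, eps, jnp.inf)) - (y - mu))
ll = lambda k, lam: k * jnp.log(jnp.clip(lam, eps)) - lam  # pointwise LL without log(k!)
assert jnp.allclose(dev, 2 * (ll(y, y) - ll(y, mu)))
```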
+ def estimate_scale(self, predicted_rate: jnp.ndarray) -> None:
+ r"""
+ Assign 1 to the scale parameter of the Poisson model.
+
+ For the Poisson exponential family distribution, the scale parameter $\phi$ is always 1.
+ This property is consistent with the fact that the variance equals the mean in a Poisson distribution.
+ As given in the general exponential family expression:
+ $$
+ \text{var}(Y) = a(\phi) V(\mu),
+ $$
+ for the Poisson family, it simplifies to $\text{var}(Y) = \mu$ since $a(\phi) = 1$ and $V(\mu) = \mu$.
+
+ Parameters
+ ----------
+ predicted_rate :
+ The predicted rate values. This is not used in the Poisson model for estimating scale,
+ but is retained for compatibility with the abstract method signature.
+ """
+ self.scale = 1.0
+
+
+def check_observation_model(observation_model):
+ """
+ Check the attributes of an observation model for compliance.
+
+ This function ensures that the observation model has the required attributes and that each
+ attribute is a callable function. Additionally, it checks if these functions return
+ jax.numpy.ndarray objects, and in the case of 'inverse_link_function', whether it is
+ differentiable.
+
+ Parameters
+ ----------
+ observation_model : object
+ An instance of an observation model that should have specific attributes.
+
+ Raises
+ ------
+ AttributeError
+ If the `observation_model` does not have one of the required attributes.
+
+ TypeError
+ - If an attribute is not a callable function.
+ - If a function does not return a jax.numpy.ndarray.
+ - If 'inverse_link_function' is not differentiable.
+
+ Examples
+ --------
+ >>> class MyObservationModel:
+ ... def inverse_link_function(self, x):
+ ... return jax.scipy.special.expit(x)
+ ... def negative_log_likelihood(self, params, y_true):
+ ... return -jnp.sum(y_true * jnp.log(params) + \
+ ... (1 - y_true) * jnp.log(1 - params))
+ ... def pseudo_r2(self, params, y_true):
+ ... return 1 - (self.negative_log_likelihood(params, y_true) /
+ ... jnp.sum((y_true - y_true.mean()) ** 2))
+ ... def sample_generator(self, key, params):
+ ... return jax.random.bernoulli(key, params)
+ >>> model = MyObservationModel()
+ >>> check_observation_model(model) # Should pass without error if the model is correctly implemented.
+ """
+ # Define the checks to be made on each attribute
+ checks = {
+ "inverse_link_function": {
+ "input": [jnp.array([1.0, 1.0, 1.0])],
+ "test_differentiable": True,
+ "test_preserve_shape": False,
+ },
+ "negative_log_likelihood": {
+ "input": [0.5 * jnp.array([1.0, 1.0, 1.0]), jnp.array([1.0, 1.0, 1.0])],
+ "test_scalar_func": True,
+ },
+ "pseudo_r2": {
+ "input": [0.5 * jnp.array([1.0, 1.0, 1.0]), jnp.array([1.0, 1.0, 1.0])],
+ "test_scalar_func": True,
+ },
+ "sample_generator": {
+ "input": [jax.random.PRNGKey(123), 0.5 * jnp.array([1.0, 1.0, 1.0])],
+ "test_preserve_shape": True,
+ },
+ }
+
+ # Perform checks for each attribute
+ for attr_name, check_info in checks.items():
+ # check if the observation model has the attribute
+ utils.assert_has_attribute(observation_model, attr_name)
+
+ # check if the attribute is a callable
+ func = getattr(observation_model, attr_name)
+ utils.assert_is_callable(func, attr_name)
+
+ # check that the callable returns an array
+ utils.assert_returns_ndarray(func, check_info["input"], attr_name)
+
+ if check_info.get("test_differentiable"):
+ utils.assert_differentiable(func, attr_name)
+
+ if "test_preserve_shape" in check_info:
+ index = int(check_info["test_preserve_shape"])
+ utils.assert_preserve_shape(
+ func, check_info["input"], attr_name, input_index=index
+ )
+
+ if check_info.get("test_scalar_func"):
+ utils.assert_scalar_func(func, check_info["input"], attr_name)
diff --git a/src/nemos/proximal_operator.py b/src/nemos/proximal_operator.py
new file mode 100644
index 00000000..76607183
--- /dev/null
+++ b/src/nemos/proximal_operator.py
@@ -0,0 +1,138 @@
+r"""Collection of proximal operators.
+
+Proximal operators are a mathematical tools used to solve non-differentiable optimization
+problems or to simplify complex ones.
+
+A classical use-case for proximal operator is that of minimizing a penalized loss function where the
+penalization is non-differentiable (Lasso, group Lasso etc.). In proximal gradient algorithms, proximal
+operators are used to find the parameters that balance the minimization of the penalty term with
+ the proximity to the gradient descent update of the un-penalized loss.
+
+More formally, proximal operators solve the minimization problem,
+
+$$
+\\text{prox}\_f(\bm{v}) = \arg\min\_{\bm{x}} \left( f(\bm{x}) + \frac{1}{2}\Vert \bm{x} - \bm{v}\Vert_2 ^2 \right)
+$$
+
+
+Where $f$ is usually the non-differentiable penalization term, and $\bm{v}$ is the parameter update of the
+un-penalized loss function. The first term controls the penalization magnitude, the second the proximity
+with the gradient based update.
+
+References
+----------
+[1] Parikh, Neal, and Stephen Boyd. *"Proximal Algorithms, ser. Foundations and Trends (r) in Optimization."* (2013).
+"""
+from typing import Tuple
+
+import jax
+import jax.numpy as jnp
+
+
+def _norm2_masked(weight_neuron: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray:
+ """Euclidean norm of the group.
+
+ Calculate the Euclidean norm of the weights for a specified group within a
+ neuron's feature vector.
+
+ This function computes the norm of elements that are indicated by the mask array.
+ If 'mask' were of boolean type, this operation would be equivalent to performing
+ `jnp.linalg.norm(weight_neuron[mask], 2)` followed by dividing the result by the
+ square root of the sum of the mask (assuming each group has at least 1 feature).
+
+ Parameters
+ ----------
+ weight_neuron:
+ The feature vector for a neuron. Shape (n_features, ).
+ mask:
+ The mask vector for group. mask[i] = 1, if the i-th element of weight_neuron
+ belongs to the group, 0 otherwise. Shape (n_features, )
+
+ Returns
+ -------
+ :
+ The norm of the weight vector corresponding to the feature in mask.
+
+ Notes
+ -----
+ The proximal gradient operator is described in Yuan and Lin[^1], Proposition 1.
+
+ [^1]:
+ Yuan, Ming, and Yi Lin. "Model selection and estimation in regression with grouped variables."
+ Journal of the Royal Statistical Society Series B: Statistical Methodology 68.1 (2006): 49-67.
+ """
+ return jnp.linalg.norm(weight_neuron * mask, 2) / jnp.sqrt(mask.sum())
+
+
+# vectorize the norm function above
+# [(n_neurons, n_features), (n_features)] -> (n_neurons, )
+_vmap_norm2_masked_1 = jax.vmap(_norm2_masked, in_axes=(0, None), out_axes=0)
+# [(n_neurons, n_features), (n_groups, n_features)] -> (n_neurons, n_groups)
+_vmap_norm2_masked_2 = jax.vmap(_vmap_norm2_masked_1, in_axes=(None, 0), out_axes=1)
+
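A minimal shape check of the two vmapped helpers (illustrative only, since these names are module-private):

```python
import jax.numpy as jnp

weights = jnp.arange(6.0).reshape(2, 3)              # (n_neurons=2, n_features=3)
mask = jnp.array([[1.0, 1.0, 0.0],
                  [0.0, 0.0, 1.0]])                  # (n_groups=2, n_features=3)

per_neuron = _vmap_norm2_masked_1(weights, mask[0])  # shape (2,): one norm per neuron
per_group = _vmap_norm2_masked_2(weights, mask)      # shape (2, 2): per neuron, per group
```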
+
+def prox_group_lasso(
+ params: Tuple[jnp.ndarray, jnp.ndarray],
+ regularizer_strength: float,
+ mask: jnp.ndarray,
+ scaling: float = 1.0,
+) -> Tuple[jnp.ndarray, jnp.ndarray]:
+ r"""Proximal gradient operator for group Lasso.
+
+ Parameters
+ ----------
+ params:
+ Weights, shape (n_neurons, n_features); intercept, shape (n_neurons, )
+ regularizer_strength:
+ The regularization hyperparameter.
+ mask:
+ 2D array of 0s and 1s (float32), the feature mask. Shape (n_groups, n_features).
+ scaling:
+ The scaling factor for the group-lasso (it will be set
+ depending on the step-size).
+
+ Returns
+ -------
+ :
+ The rescaled weights.
+
+ Notes
+ -----
+ This function implements the proximal operator for a group-Lasso penalization, which
+ can be derived in analytical form.
+ The proximal operator equation is,
+
+ $$
+ \text{prox}(\hat{\beta}) = \arg\min\_{\beta} \left[ \lambda \sum\_{g=1}^G \sqrt{p\_g} \Vert \beta\_g \Vert_2 +
+ \frac{1}{2} \Vert \hat{\beta} - \beta \Vert_2^2
+ \right],
+ $$
+ where $G$ is the number of groups, $\beta\_g$ is the parameter vector
+ associated with the $g$-th group, and $p\_g$ is its dimensionality.
+ The analytical solution[^1] for each group is,
+
+ $$
+ \text{prox}(\hat{\beta})\_g = \max \left(1 - \frac{\lambda \sqrt{p\_g}}{\Vert \hat{\beta}\_g \Vert_2},
+ 0\right) \cdot \hat{\beta}\_g,
+ $$
+ where $\hat{\beta}$ is typically the gradient step
+ of the un-regularized optimization objective function. The group-Lasso
+ proximal operator thus acts as a shrinkage factor on the un-penalized update, with a half-rectification
+ non-linearity that effectively sets to zero any group of coefficients satisfying,
+ $$
+ \Vert \hat{\beta}\_g \Vert_2 \le \lambda \sqrt{p\_g}.
+ $$
+
+ [^1]:
+ Yuan, Ming, and Yi Lin. "Model selection and estimation in regression with grouped variables."
+ Journal of the Royal Statistical Society Series B: Statistical Methodology 68.1 (2006): 49-67.
+ """
+ weights, intercepts = params
+ # [(n_neurons, n_features), (n_groups, n_features)] -> (n_neurons, n_groups)
+ l2_norm = _vmap_norm2_masked_2(weights, mask)
+ factor = 1 - regularizer_strength * scaling / l2_norm
+ factor = jax.nn.relu(factor)
+ # Avoid shrinkage of features that do not belong to any group
+ # by setting the shrinkage factor to 1.
+ not_regularized = jnp.outer(jnp.ones(factor.shape[0]), 1 - mask.sum(axis=0))
+ return weights * (factor @ mask + not_regularized), intercepts
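A hedged usage sketch with the shapes documented above; the intercepts pass through untouched while each weight group is shrunk by its own factor:

```python
import jax.numpy as jnp

weights = jnp.ones((3, 4))                # (n_neurons, n_features)
intercepts = jnp.zeros(3)
mask = jnp.array([[1.0, 0.0, 1.0, 0.0],
                  [0.0, 1.0, 0.0, 1.0]])  # two groups covering all four features

new_w, new_b = prox_group_lasso((weights, intercepts), 0.1, mask, scaling=0.5)
```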
diff --git a/src/nemos/regularizer.py b/src/nemos/regularizer.py
new file mode 100644
index 00000000..7e4a26e0
--- /dev/null
+++ b/src/nemos/regularizer.py
@@ -0,0 +1,501 @@
+"""
+A Module for Optimization with Various Regularization Schemes.
+
+This module provides a series of classes that facilitate the optimization of models
+with different types of regularization. Each `Regularizer` class in this module interfaces
+with various optimization methods, which can be selected depending on the model's requirements.
+"""
+import abc
+import inspect
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import jax.numpy as jnp
+import jaxopt
+from numpy.typing import NDArray
+
+from . import utils
+from .base_class import Base
+from .proximal_operator import prox_group_lasso
+
+SolverRunner = Callable[
+ [
+ Tuple[
+ jnp.ndarray, jnp.ndarray
+ ], # Model parameters (for now tuple, eventually pytree)
+ jnp.ndarray, # Predictors (i.e. model design for GLM)
+ jnp.ndarray, # Output (neural activity)
+ ],
+ jaxopt.OptStep,
+]
+
+ProximalOperator = Callable[
+ [
+ Tuple[
+ jnp.ndarray, jnp.ndarray
+ ], # Model parameters (for now tuple, eventually pytree)
+ float, # Regularizer strength (for now float, eventually pytree)
+ float, # Step-size for optimization (must be a float)
+ ],
+ Tuple[jnp.ndarray, jnp.ndarray],
+]
+
+__all__ = ["UnRegularized", "Ridge", "Lasso", "GroupLasso"]
+
+
+def __dir__() -> list[str]:
+ return __all__
+
+
+class Regularizer(Base, abc.ABC):
+ """
+ Abstract base class for regularized solvers.
+
+ This class is designed to provide a consistent interface for optimization solvers,
+ enabling users to easily switch between different regularizers, ensuring compatibility
+ with various loss functions and optimization algorithms.
+
+ Attributes
+ ----------
+ allowed_solvers :
+ List of solver names that are allowed for use with this regularizer.
+ solver_name :
+ Name of the solver being used.
+ solver_kwargs :
+ Additional keyword arguments to be passed to the solver during instantiation.
+ """
+
+ allowed_solvers: List[str] = []
+
+ def __init__(
+ self,
+ solver_name: str,
+ solver_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self._check_solver(solver_name)
+ self._solver_name = solver_name
+ if solver_kwargs is None:
+ self._solver_kwargs = dict()
+ else:
+ self._solver_kwargs = solver_kwargs
+ self._check_solver_kwargs(self.solver_name, self.solver_kwargs)
+
+ @property
+ def solver_name(self):
+ return self._solver_name
+
+ @solver_name.setter
+ def solver_name(self, solver_name: str):
+ self._check_solver(solver_name)
+ self._solver_name = solver_name
+
+ @property
+ def solver_kwargs(self):
+ return self._solver_kwargs
+
+ @solver_kwargs.setter
+ def solver_kwargs(self, solver_kwargs: dict):
+ self._check_solver_kwargs(self.solver_name, solver_kwargs)
+ self._solver_kwargs = solver_kwargs
+
+ def _check_solver(self, solver_name: str):
+ """
+ Ensure the provided solver name is allowed.
+
+ Parameters
+ ----------
+ solver_name :
+ Name of the solver to be checked.
+
+ Raises
+ ------
+ ValueError
+ If the provided solver name is not in the list of allowed solvers.
+ """
+ if solver_name not in self.allowed_solvers:
+ raise ValueError(
+ f"Solver `{solver_name}` not allowed for "
+ f"{self.__class__} regularization. "
+ f"Allowed solvers are {self.allowed_solvers}."
+ )
+
+ @staticmethod
+ def _check_solver_kwargs(solver_name, solver_kwargs):
+ """
+ Check if provided solver keyword arguments are valid.
+
+ Parameters
+ ----------
+ solver_name :
+ Name of the solver.
+ solver_kwargs :
+ Additional keyword arguments for the solver.
+
+ Raises
+ ------
+ NameError
+ If any of the solver keyword arguments are not valid.
+ """
+ solver_args = inspect.getfullargspec(getattr(jaxopt, solver_name)).args
+ undefined_kwargs = set(solver_kwargs.keys()).difference(solver_args)
+ if undefined_kwargs:
+ raise NameError(
+ f"kwargs {undefined_kwargs} in solver_kwargs not a kwarg for jaxopt.{solver_name}!"
+ )
+
+ def instantiate_solver(
+ self, loss: Callable, *args: Any, **kwargs: Any
+ ) -> SolverRunner:
+ """
+ Instantiate the solver with the provided loss function.
+
+ Parameters
+ ----------
+ loss :
+ The loss function to be optimized.
+
+ *args:
+ Positional arguments for the jaxopt `solver.run` method, e.g. the regularizing
+ strength for proximal gradient methods.
+
+ **kwargs:
+ Keyword arguments for the jaxopt `solver.run` method.
+
+ Returns
+ -------
+ :
+ A function that runs the solver with the provided loss and proximal operator.
+ """
+ # check that the loss is Callable
+ utils.assert_is_callable(loss, "loss")
+
+ # get the solver with given arguments.
+ # The "fun" argument is not always the first one, but it is always KEYWORD
+ # see jaxopt.EqualityConstrainedQP for example. The most general way is to pass it as keyword.
+ solver = getattr(jaxopt, self.solver_name)(fun=loss, **self.solver_kwargs)
+
+ def solver_run(
+ init_params: Tuple[jnp.ndarray, jnp.ndarray], *run_args: jnp.ndarray
+ ) -> jaxopt.OptStep:
+ return solver.run(init_params, *args, *run_args, **kwargs)
+
+ return solver_run
+
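A sketch of the intended call pattern (the loss below and names like `init_params`, `X`, `y` are placeholders, not part of this diff):

```python
import jax.numpy as jnp
import nemos as nmo

def loss(params, X, y):
    # any callable (params, X, y) -> scalar works; this is a toy Poisson-style loss
    w, b = params
    rate = jnp.exp(jnp.einsum("ik,tik->ti", w, X) + b)
    return jnp.mean(rate - y * jnp.log(rate))

runner = nmo.regularizer.UnRegularized("GradientDescent").instantiate_solver(loss)
# params, state = runner(init_params, X, y)  # returns a jaxopt.OptStep
```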
+
+class UnRegularized(Regularizer):
+ """
+ Regularizer class for optimizing unregularized models.
+
+ This class provides an interface to various optimization methods for models that
+ do not involve regularization. The optimization methods that are allowed for this
+ class are defined in the `allowed_solvers` attribute.
+
+ Attributes
+ ----------
+ allowed_solvers : list of str
+ List of solver names that are allowed for this regularizer class.
+
+ See Also
+ --------
+ [Regularizer](./#nemos.regularizer.Regularizer) : Base solver class from which this class inherits.
+ """
+
+ allowed_solvers = [
+ "GradientDescent",
+ "BFGS",
+ "LBFGS",
+ "ScipyMinimize",
+ "NonlinearCG",
+ "ScipyBoundedMinimize",
+ "LBFGSB",
+ ]
+
+ def __init__(
+ self, solver_name: str = "GradientDescent", solver_kwargs: Optional[dict] = None
+ ):
+ super().__init__(solver_name, solver_kwargs=solver_kwargs)
+
+
+class Ridge(Regularizer):
+ """
+ Regularizer class for Ridge (L2) regularization using various optimization algorithms.
+
+ This class uses `jaxopt` optimizers to perform Ridge regularization. It extends
+ the base Regularizer class, with the added feature of Ridge penalization.
+
+ Attributes
+ ----------
+ allowed_solvers : List[str]
+ A list of solver names that are allowed to be used with this regularizer.
+ """
+
+ allowed_solvers = [
+ "GradientDescent",
+ "BFGS",
+ "LBFGS",
+ "ScipyMinimize",
+ "NonlinearCG",
+ "ScipyBoundedMinimize",
+ "LBFGSB",
+ ]
+
+ def __init__(
+ self,
+ solver_name: str = "GradientDescent",
+ solver_kwargs: Optional[dict] = None,
+ regularizer_strength: float = 1.0,
+ ):
+ super().__init__(solver_name, solver_kwargs=solver_kwargs)
+ self.regularizer_strength = regularizer_strength
+
+ def _penalization(self, params: Tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
+ """
+ Compute the Ridge penalization for given parameters.
+
+ Parameters
+ ----------
+ params :
+ Model parameters for which to compute the penalization.
+
+ Returns
+ -------
+ float
+ The Ridge penalization value.
+ """
+ return (
+ 0.5
+ * self.regularizer_strength
+ * jnp.sum(jnp.power(params[0], 2))
+ / params[1].shape[0]
+ )
+
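A worked instance of the formula above; note the division by the number of neurons, `params[1].shape[0]`:

```python
import jax.numpy as jnp

w = jnp.ones((2, 3))  # (n_neurons=2, n_features=3), so the sum of squares is 6
b = jnp.zeros(2)
strength = 1.0
penalty = 0.5 * strength * jnp.sum(jnp.power(w, 2)) / b.shape[0]  # 0.5 * 1.0 * 6 / 2 = 1.5
```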
+ def instantiate_solver(
+ self, loss: Callable, *args: Any, **kwargs: Any
+ ) -> SolverRunner:
+ """
+ Instantiate the solver with a penalized loss function.
+
+ Parameters
+ ----------
+ loss :
+ The original loss function to be optimized.
+
+ Returns
+ -------
+ Callable
+ A function that runs the solver with the penalized loss.
+ """
+ # this check has to be performed here because the penalized loss
+ # is always callable, regardless of whether the provided loss is!
+ utils.assert_is_callable(loss, "loss")
+
+ def penalized_loss(params, X, y):
+ return loss(params, X, y) + self._penalization(params)
+
+ return super().instantiate_solver(penalized_loss)
+
+
+class ProxGradientRegularizer(Regularizer, abc.ABC):
+ """
+ Abstract class for optimization solvers using the Proximal Gradient method.
+
+ This class utilizes the `jaxopt` library's Proximal Gradient optimizer. It extends
+ the base Regularizer class, with the added functionality of a proximal operator.
+
+ Attributes
+ ----------
+ allowed_solvers : List[str]
+ A list of solver names that are allowed to be used with this regularizer.
+ """
+
+ allowed_solvers = ["ProximalGradient"]
+
+ def __init__(
+ self,
+ solver_name: str,
+ solver_kwargs: Optional[dict] = None,
+ regularizer_strength: float = 1.0,
+ **kwargs,
+ ):
+ if solver_kwargs is None:
+ solver_kwargs = dict(prox=self._get_proximal_operator())
+ else:
+ solver_kwargs["prox"] = self._get_proximal_operator()
+ super().__init__(solver_name, solver_kwargs=solver_kwargs)
+ self.regularizer_strength = regularizer_strength
+
+ @abc.abstractmethod
+ def _get_proximal_operator(
+ self,
+ ) -> ProximalOperator:
+ """
+ Abstract method to retrieve the proximal operator for this solver.
+
+ Returns
+ -------
+ :
+ The proximal operator, which typically applies a form of regularization.
+ """
+ pass
+
+ def instantiate_solver(
+ self, loss: Callable, *args: Any, **kwargs: Any
+ ) -> SolverRunner:
+ """
+ Instantiate the solver with the provided loss function and proximal operator.
+
+ Parameters
+ ----------
+ loss :
+ The original loss function to be optimized.
+
+ Returns
+ -------
+ :
+ A function that runs the solver with the provided loss and proximal operator.
+ """
+ return super().instantiate_solver(loss, self.regularizer_strength)
+
+
+class Lasso(ProxGradientRegularizer):
+ """
+ Optimization solver using the Lasso (L1 regularization) method with Proximal Gradient.
+
+ This class is a specialized version of the ProxGradientRegularizer with the proximal operator
+ set for L1 regularization (Lasso). It utilizes the `jaxopt` library's proximal gradient optimizer.
+ """
+
+ def __init__(
+ self,
+ solver_name: str = "ProximalGradient",
+ solver_kwargs: Optional[dict] = None,
+ regularizer_strength: float = 1.0,
+ ):
+ super().__init__(solver_name, solver_kwargs=solver_kwargs)
+ self.regularizer_strength = regularizer_strength
+
+ def _get_proximal_operator(
+ self,
+ ) -> ProximalOperator:
+ """
+ Retrieve the proximal operator for Lasso regularization (L1 penalty).
+
+ Returns
+ -------
+ :
+ The proximal operator, applying L1 regularization to the provided parameters. The intercept
+ term is not regularized.
+ """
+
+ def prox_op(params, l1reg, scaling=1.0):
+ Ws, bs = params
+ return jaxopt.prox.prox_lasso(Ws, l1reg, scaling=scaling), bs
+
+ return prox_op
+
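For intuition, `jaxopt.prox.prox_lasso` soft-thresholds each weight toward zero (a sketch with made-up numbers):

```python
import jax.numpy as jnp
import jaxopt

w = jnp.array([[0.3, -0.05, 1.2]])
shrunk = jaxopt.prox.prox_lasso(w, l1reg=0.1)  # [[0.2, 0.0, 1.1]]
```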
+
+class GroupLasso(ProxGradientRegularizer):
+ """
+ Optimization solver using the Group Lasso regularization method with Proximal Gradient.
+
+ This class is a specialized version of the ProxGradientRegularizer with the proximal operator
+ set for Group Lasso regularization. The Group Lasso regularization induces sparsity on groups
+ of features rather than individual features.
+
+ Attributes
+ ----------
+ mask : Union[jnp.ndarray, NDArray]
+ A 2d mask array indicating groups of features for regularization.
+ Each row represents a group of features.
+ Each column corresponds to a feature, where a value of 1 indicates that the feature belongs
+ to the group, and a value of 0 indicates it doesn't.
+ """
+
+ def __init__(
+ self,
+ solver_name: str,
+ mask: Union[NDArray, jnp.ndarray],
+ solver_kwargs: Optional[dict] = None,
+ regularizer_strength: float = 1.0,
+ ):
+ super().__init__(
+ solver_name,
+ solver_kwargs=solver_kwargs,
+ )
+ self.regularizer_strength = regularizer_strength
+ self.mask = jnp.asarray(mask)
+
+ @property
+ def mask(self):
+ """Getter for the mask attribute."""
+ return self._mask
+
+ @mask.setter
+ def mask(self, mask: jnp.ndarray):
+ self._check_mask(mask)
+ self._mask = mask
+
+ @staticmethod
+ def _check_mask(mask: jnp.ndarray):
+ """
+ Validate the mask array.
+
+ This method ensures the mask adheres to requirements:
+ - It should be 2-dimensional.
+ - Each element must be either 0 or 1.
+ - Each feature should belong to only one group.
+ - The mask should not be empty.
+ - The mask is an array of float type.
+
+ Raises
+ ------
+ ValueError
+ If any of the above conditions are not met.
+ """
+ if mask.ndim != 2:
+ raise ValueError(
+ "`mask` must be 2-dimensional. "
+ f"{mask.ndim} dimensional mask provided instead!"
+ )
+
+ if mask.shape[0] == 0:
+ raise ValueError(f"Empty mask provided! Mask has shape {mask.shape}.")
+
+ if jnp.any((mask != 1) & (mask != 0)):
+ raise ValueError("Mask elements be 0s and 1s!")
+
+ if mask.sum() == 0:
+ raise ValueError("Empty mask provided!")
+
+ if jnp.any(mask.sum(axis=0) > 1):
+ raise ValueError(
+ "Incorrect group assignment. Some of the features are assigned "
+ "to more then one group."
+ )
+
+ if not jnp.issubdtype(mask.dtype, jnp.floating):
+ raise ValueError(
+ "Mask should be a floating point jnp.ndarray. "
+ f"Data type {mask.dtype} provided instead!"
+ )
+
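A minimal sketch of building a valid mask and passing it to the regularizer (mirroring the constraints checked above):

```python
import numpy as np
import nemos as nmo

mask = np.zeros((2, 5))  # 2 groups, 5 features; float dtype, as required
mask[0, :2] = 1.0        # features 0-1 -> group 0
mask[1, 2:] = 1.0        # features 2-4 -> group 1 (no feature in two groups)

reg = nmo.regularizer.GroupLasso("ProximalGradient", mask=mask)
```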
+ def _get_proximal_operator(
+ self,
+ ) -> ProximalOperator:
+ """
+ Retrieve the proximal operator for Group Lasso regularization.
+
+ Returns
+ -------
+ :
+ The proximal operator, applying Group Lasso regularization to the provided parameters. The
+ intercept term is not regularized.
+ """
+
+ def prox_op(params, regularizer_strength, scaling=1.0):
+ return prox_group_lasso(
+ params, regularizer_strength, mask=self.mask, scaling=scaling
+ )
+
+ return prox_op
diff --git a/src/nemos/sample_points.py b/src/nemos/sample_points.py
index a0c8a496..fccf5bec 100644
--- a/src/nemos/sample_points.py
+++ b/src/nemos/sample_points.py
@@ -1,5 +1,4 @@
-"""Helper functions for generating arrays of sample points, for basis functions.
-"""
+"""Helper functions for generating arrays of sample points, for basis functions."""
import numpy as np
from numpy.typing import NDArray
@@ -39,7 +38,7 @@ def raised_cosine_log(n_basis_funcs: int, window_size: int) -> NDArray:
def raised_cosine_linear(n_basis_funcs: int, window_size: int) -> NDArray:
- """Generate linear-spaced sample points for RaisedCosineBasis
+ """Generate linear-spaced sample points for RaisedCosineBasis.
When used with the RaisedCosineBasis, results in evenly (linear) spaced
cosine bumps.
diff --git a/src/nemos/simulation.py b/src/nemos/simulation.py
new file mode 100644
index 00000000..43680912
--- /dev/null
+++ b/src/nemos/simulation.py
@@ -0,0 +1,152 @@
+"""Utility functions for coupling filter definition."""
+
+
+import numpy as np
+import scipy.stats as sts
+from numpy.typing import NDArray
+
+
+def difference_of_gammas(
+ ws: int,
+ upper_percentile: float = 0.99,
+ inhib_a: float = 1.0,
+ excit_a: float = 2.0,
+ inhib_b: float = 1.0,
+ excit_b: float = 2.0,
+) -> NDArray:
+ r"""Generate coupling filter as a Gamma pdf difference.
+
+ Parameters
+ ----------
+ ws:
+ The window size of the filter.
+ upper_percentile:
+ Upper bound of the gamma range as a percentile. The gamma function
+ will be evaluated over the range [0, ppf(upper_percentile)].
+ inhib_a:
+ The `a` constant for the gamma pdf of the inhibitory part of the filter.
+ excit_a:
+ The `a` constant for the gamma pdf of the excitatory part of the filter.
+ inhib_b:
+ The `b` constant for the gamma pdf of the inhibitory part of the filter.
+ excit_b:
+ The `b` constant for the gamma pdf of the excitatory part of the filter.
+
+ Returns
+ -------
+ filter:
+ The coupling filter.
+
+ Raises
+ ------
+ ValueError:
+ - If any of the Gamma parameters is less than or equal to 0.
+ - If the upper_percentile is not in [0, 1).
+
+ Notes
+ -----
+ The probability density function of a gamma distribution is parametrized as
+ follows$^1$,
+ $$
+ p(x;\; a, b) = \frac{b^a x^{a-1} e^{-bx}}{\Gamma(a)},
+ $$
+ where $\Gamma(a)$ refers to the gamma function, see$^1$.
+
+ References
+ ----------
+ 1. [SciPy Docs - "scipy.stats.gamma"](https://docs.scipy.org/doc/
+ scipy/reference/generated/scipy.stats.gamma.html)
+ """
+ # check that the gamma parameters are positive (scipy returns
+ # nans otherwise but no exception is raised)
+ variables = {
+ "excit_a": excit_a,
+ "inhib_a": inhib_a,
+ "excit_b": excit_b,
+ "inhib_b": inhib_b,
+ }
+ for name, value in variables.items():
+ if value <= 0:
+ raise ValueError(f"Gamma parameter {name} must be >0.")
+ # check for valid percentile
+ if upper_percentile < 0 or upper_percentile >= 1:
+ raise ValueError(
+ f"upper_percentile should lie in the [0, 1) interval. {upper_percentile} provided instead!"
+ )
+
+ gm_inhibition = sts.gamma(a=inhib_a, scale=1 / inhib_b)
+ gm_excitation = sts.gamma(a=excit_a, scale=1 / excit_b)
+
+ # calculate upper bound for the evaluation
+ xmax = max(gm_inhibition.ppf(upper_percentile), gm_excitation.ppf(upper_percentile))
+ # equi-spaced sample covering the range
+ x = np.linspace(0, xmax, ws)
+
+ # compute difference of gammas & normalize
+ gamma_diff = gm_excitation.pdf(x) - gm_inhibition.pdf(x)
+ gamma_diff = gamma_diff / np.linalg.norm(gamma_diff, ord=2)
+
+ return gamma_diff
+
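A small usage sketch; the assertions follow directly from the code above (equi-spaced grid of `ws` points, L2-normalized output):

```python
import numpy as np

filt = difference_of_gammas(ws=100)
assert filt.shape == (100,)
assert np.isclose(np.linalg.norm(filt), 1.0)  # the filter is L2-normalized
```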
+
+def regress_filter(coupling_filters: NDArray, eval_basis: NDArray) -> NDArray:
+ """Approximate scipy.stats.gamma based filters with basis function.
+
+ Find the Ordinary Least Squares weights for representing the filters in terms of basis functions.
+
+ Parameters
+ ----------
+ coupling_filters:
+ The coupling filters. Shape (window_size, n_neurons_receiver, n_neurons_sender)
+ eval_basis:
+ The evaluated basis function, shape (window_size, n_basis_funcs)
+
+ Returns
+ -------
+ weights:
+ The weights for each neuron. Shape (n_neurons_receiver, n_neurons_sender, n_basis_funcs)
+
+ Raises
+ ------
+ ValueError
+ - If eval_basis is not two-dimensional
+ - If coupling_filters is not three-dimensional
+ - If window_size differs between eval_basis and coupling_filters
+ """
+ # check shapes
+ if eval_basis.ndim != 2:
+ raise ValueError(
+ "eval_basis must be a 2 dimensional array, "
+ "shape (window_size, n_basis_funcs). "
+ f"{eval_basis.ndim} dimensional array provided instead!"
+ )
+ if coupling_filters.ndim != 3:
+ raise ValueError(
+ "coupling_filters must be a 3 dimensional array, "
+ "shape (window_size, n_neurons, n_neurons). "
+ f"{coupling_filters.ndim} dimensional array provided instead!"
+ )
+
+ ws, n_neurons_receiver, n_neurons_sender = coupling_filters.shape
+
+ # check that window size matches
+ if eval_basis.shape[0] != ws:
+ raise ValueError(
+ "window_size mismatch. The window size of coupling_filters and eval_basis "
+ f"does not match. coupling_filters has a window size of {ws}; "
+ f"eval_basis has a window size of {eval_basis.shape[0]}."
+ )
+
+ # Reshape the coupling_filters for vectorized least-squares
+ filters_reshaped = coupling_filters.reshape(ws, -1)
+
+ # Solve the least squares problem for all filters at once
+ # (vectorizing over the receiver/sender neuron dimensions)
+ weights = np.linalg.lstsq(eval_basis, filters_reshaped, rcond=None)[0]
+
+ # Reshape back to the original dimensions
+ weights = np.transpose(
+ weights.reshape(-1, n_neurons_receiver, n_neurons_sender), axes=(1, 2, 0)
+ )
+
+ return weights
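A sketch of the round trip: regress random filters onto a basis, then reconstruct them as a weighted sum of the basis columns (array contents are placeholders):

```python
import numpy as np

ws, n_neurons, n_basis = 100, 2, 10
eval_basis = np.random.rand(ws, n_basis)
filters = np.random.rand(ws, n_neurons, n_neurons)

weights = regress_filter(filters, eval_basis)           # (n_neurons, n_neurons, n_basis)
approx = np.einsum("tb,ijb->tij", eval_basis, weights)  # OLS approximation of `filters`
```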
diff --git a/src/nemos/utils.py b/src/nemos/utils.py
index ef87d04e..bd8a3d72 100644
--- a/src/nemos/utils.py
+++ b/src/nemos/utils.py
@@ -1,10 +1,9 @@
-"""Utility functions for data pre-processing
-"""
+"""Utility functions for data pre-processing."""
# required to get ArrayLike to render correctly, unnecessary as of python 3.10
from __future__ import annotations
from functools import partial
-from typing import Iterable, List, Literal, Optional, Union
+from typing import Any, Callable, Iterable, List, Literal, Optional, Union
import jax
import jax.numpy as jnp
@@ -52,8 +51,7 @@ def convolve_1d_trials(
basis_matrix: ArrayLike,
time_series: Union[Iterable[NDArray], NDArray, Iterable[jnp.ndarray], jnp.ndarray],
) -> List[jnp.ndarray]:
- """
- Convolve trial time series with a basis matrix.
+ """Convolve trial time series with a basis matrix.
This function checks if all trials have the same duration. If they do, it uses a fast method
to convolve all trials with the basis matrix at once. If they do not, it falls back to convolving
@@ -82,7 +80,6 @@ def convolve_1d_trials(
- If trials_time_series contains empty trials.
- If the number of time points in each trial is less than the window size.
"""
-
basis_matrix = jnp.asarray(basis_matrix)
# check input size
if basis_matrix.ndim != 2:
@@ -143,6 +140,7 @@ def _pad_dimension(
Add padding to the last dimension of an array based on the convolution type.
This is a helper function used by `nan_pad_conv`, which is the function we expect the user will interact with.
+
Parameters
----------
array:
@@ -218,6 +216,7 @@ def nan_pad_conv(
ValueError
If the window_size is not a positive integer, or if the filter_type is not one of 'causal',
'acausal', or 'anti-causal'. Also raises ValueError if the dimensionality of conv_trials is not as expected.
+
"""
if not isinstance(window_size, int) or window_size <= 0:
raise ValueError(
@@ -340,9 +339,10 @@ def plot_spike_raster(
def row_wise_kron(A: jnp.array, C: jnp.array, jit=False, transpose=True) -> jnp.array:
- """
- Compute the row-wise Kronecker product between two matrices using JAX. See [1]
- for more details on the Kronecker product.
+ r"""Compute the row-wise Kronecker product.
+
+ Compute the row-wise Kronecker product between two matrices using JAX.
+ See [\[1\]](#references) for more details on the Kronecker product.
Parameters
----------
@@ -365,10 +365,11 @@ def row_wise_kron(A: jnp.array, C: jnp.array, jit=False, transpose=True) -> jnp.
This function computes the row-wise Kronecker product between dense matrices A and C
using JAX for automatic differentiation and GPU acceleration.
- .. [1] Petersen, Kaare Brandt, and Michael Syskind Pedersen. "The matrix cookbook."
- Technical University of Denmark 7.15 (2008): 510.
+ References
+ ----------
+ 1. Petersen, Kaare Brandt, and Michael Syskind Pedersen. "The matrix cookbook."
+ Technical University of Denmark 7.15 (2008): 510.
"""
-
if transpose:
A = A.T
C = C.T
@@ -383,3 +384,78 @@ def row_wise_kron(a, c):
K = K.T
return K
+
+
+def check_invalid_entry(array: jnp.ndarray, array_name: str) -> None:
+ """Check if the array has nans or infs.
+
+ Parameters
+ ----------
+ array:
+ The array to be checked.
+ array_name:
+ The array name.
+
+ Raises
+ ------
+ ValueError
+ If any entry of `array` is either NaN or inf.
+
+ """
+ if jnp.any(jnp.isinf(array)):
+ raise ValueError(f"Input array '{array_name}' contains Infs!")
+ elif jnp.any(jnp.isnan(array)):
+ raise ValueError(f"Input array '{array_name}' contains NaNs!")
+
+
+def assert_has_attribute(obj: Any, attr_name: str):
+ """Ensure the object has the given attribute."""
+ if not hasattr(obj, attr_name):
+ raise AttributeError(
+ f"The provided object does not have the required `{attr_name}` attribute!"
+ )
+
+
+def assert_is_callable(func: Callable, func_name: str):
+ """Ensure the provided function is callable."""
+ if not callable(func):
+ raise TypeError(f"The `{func_name}` must be a Callable!")
+
+
+def assert_returns_ndarray(
+ func: Callable, inputs: Union[List[jnp.ndarray], List[float]], func_name: str
+):
+ """Ensure the function returns a jax.numpy.ndarray."""
+ array_out = func(*inputs)
+ if not isinstance(array_out, jnp.ndarray):
+ raise TypeError(f"The `{func_name}` must return a jax.numpy.ndarray!")
+
+
+def assert_differentiable(func: Callable, func_name: str):
+ """Ensure the function is differentiable."""
+ try:
+ gradient_fn = jax.grad(func)
+ gradient_fn(jnp.array(1.0))
+ except Exception as e:
+ raise TypeError(f"The `{func_name}` is not differentiable. Error: {str(e)}")
+
+
+def assert_preserve_shape(
+ func: Callable, inputs: List[jnp.ndarray], func_name: str, input_index: int
+):
+ """Check that the function preserve the input shape."""
+ result = func(*inputs)
+ if not result.shape == inputs[input_index].shape:
+ raise ValueError(f"The `{func_name}` must preserve the input array shape!")
+
+
+def assert_scalar_func(func: Callable, inputs: List[jnp.ndarray], func_name: str):
+ """Check that `func` return an array containing a single scalar."""
+ assert_returns_ndarray(func, inputs, func_name)
+ array_out = func(*inputs)
+ try:
+ float(array_out)
+ except TypeError:
+ raise TypeError(
+ f"The `{func_name}` should return a scalar! "
+ f"Array of shape {array_out.shape} returned instead!"
+ )
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..2bfa9e0a
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,192 @@
+"""
+Testing configurations for the `nemos` library.
+
+This module contains test fixtures required to set up and verify the functionality
+of the modules of the `nemos` library.
+
+Note:
+ This module primarily serves as a utility for test configurations, setting up initial conditions,
+ and loading predefined parameters for testing various functionalities of the `nemos` library.
+"""
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+
+import nemos as nmo
+
+
+@pytest.fixture
+def poissonGLM_model_instantiation():
+ """Set up a Poisson GLM for testing purposes.
+
+ This fixture initializes a Poisson GLM with random parameters, simulates its response, and
+ returns the test data, expected output, the model instance, true parameters, and the rate
+ of response.
+
+ Returns:
+ tuple: A tuple containing:
+ - X (numpy.ndarray): Simulated input data.
+ - np.random.poisson(rate) (numpy.ndarray): Simulated spike responses.
+ - model (nmo.glm.GLM): Initialized model instance.
+ - (w_true, b_true) (tuple): True weight and bias parameters.
+ - rate (jax.numpy.ndarray): Simulated rate of response.
+ """
+ np.random.seed(123)
+ X = np.random.normal(size=(100, 1, 5))
+ b_true = np.zeros((1,))
+ w_true = np.random.normal(size=(1, 5))
+ observation_model = nmo.observation_models.PoissonObservations(jnp.exp)
+ regularizer = nmo.regularizer.UnRegularized("GradientDescent", {})
+ model = nmo.glm.GLM(observation_model, regularizer)
+ rate = jax.numpy.exp(jax.numpy.einsum("ik,tik->ti", w_true, X) + b_true[None, :])
+ return X, np.random.poisson(rate), model, (w_true, b_true), rate
+
+
+@pytest.fixture
+def poissonGLM_coupled_model_config_simulate():
+ """Set up a Poisson GLM from a predefined configuration in a json file.
+
+ This fixture reads parameters for a Poisson GLM from a json configuration file, initializes
+ the model accordingly, and returns the model instance with other related parameters.
+
+ Returns:
+ tuple: A tuple containing:
+ - model (nmo.glm.GLMRecurrent): Initialized model instance.
+ - coupling_basis (jax.numpy.ndarray): Coupling basis values from the config.
+ - feedforward_input (jax.numpy.ndarray): Feedforward input values from the config.
+ - init_spikes (jax.numpy.ndarray): Initial spike values from the config.
+ - jax.random.PRNGKey(123) (jax.random.PRNGKey): A pseudo-random number generator key.
+ """
+ observations = nmo.observation_models.PoissonObservations(jnp.exp)
+ regularizer = nmo.regularizer.Ridge("BFGS", regularizer_strength=0.1)
+ model = nmo.glm.GLMRecurrent(
+ observation_model=observations, regularizer=regularizer
+ )
+
+ n_neurons, coupling_duration, sim_duration = 2, 100, 1000
+ coupling_filter_bank = np.zeros((coupling_duration, n_neurons, n_neurons))
+ for unit_i in range(n_neurons):
+ for unit_j in range(n_neurons):
+ coupling_filter_bank[:, unit_i, unit_j] = nmo.simulation.difference_of_gammas(
+ coupling_duration
+ )
+ # shrink the filters for simulation stability
+ coupling_filter_bank *= 0.8
+ basis = nmo.basis.RaisedCosineBasisLog(20)
+
+ # approximate the coupling filters in terms of the basis function
+ _, coupling_basis = basis.evaluate_on_grid(coupling_filter_bank.shape[0])
+ coupling_coeff = nmo.simulation.regress_filter(coupling_filter_bank, coupling_basis)
+
+ model.coef_ = jnp.hstack((coupling_coeff.reshape(n_neurons, -1), np.ones((n_neurons, 2))))
+ model.intercept_ = -3 * jnp.ones(n_neurons)
+ feedforward_input = jnp.c_[
+ jnp.cos(jnp.linspace(0, np.pi*4, sim_duration)),
+ jnp.sin(jnp.linspace(0, np.pi*4, sim_duration))
+ ]
+ feedforward_input = jnp.tile(feedforward_input[:, None], (1, n_neurons, 1))
+ init_spikes = jnp.zeros((coupling_duration, n_neurons))
+
+ return (
+ model,
+ coupling_basis,
+ feedforward_input,
+ init_spikes,
+ jax.random.PRNGKey(123),
+ )
+
+
+@pytest.fixture
+def jaxopt_solvers():
+ return [
+ "GradientDescent",
+ "BFGS",
+ "LBFGS",
+ "ScipyMinimize",
+ "NonlinearCG",
+ "ScipyBoundedMinimize",
+ "LBFGSB",
+ "ProximalGradient",
+ ]
+
+
+@pytest.fixture
+def group_sparse_poisson_glm_model_instantiation():
+ """Set up a Poisson GLM for testing purposes with group sparse weights.
+
+ This fixture initializes a Poisson GLM with random, group-sparse parameters, simulates its
+ response, and returns the test data, expected output, the model instance, true parameters,
+ the rate of response, and the group mask.
+
+ Returns:
+ tuple: A tuple containing:
+ - X (numpy.ndarray): Simulated input data.
+ - np.random.poisson(rate) (numpy.ndarray): Simulated spike responses.
+ - model (nmo.glm.GLM): Initialized model instance.
+ - (w_true, b_true) (tuple): True weight and bias parameters.
+ - rate (jax.numpy.ndarray): Simulated rate of response.
+ - mask (numpy.ndarray): The group-sparsity mask.
+ """
+ np.random.seed(123)
+ X = np.random.normal(size=(100, 1, 5))
+ b_true = np.zeros((1,))
+ w_true = np.random.normal(size=(1, 5))
+ w_true[0, 1:4] = 0.0
+ mask = np.zeros((2, 5))
+ mask[0, 1:4] = 1
+ mask[1, [0, 4]] = 1
+ observation_model = nmo.observation_models.PoissonObservations(jnp.exp)
+ regularizer = nmo.regularizer.UnRegularized("GradientDescent", {})
+ model = nmo.glm.GLM(observation_model, regularizer)
+ rate = jax.numpy.exp(jax.numpy.einsum("ik,tik->ti", w_true, X) + b_true[None, :])
+ return X, np.random.poisson(rate), model, (w_true, b_true), rate, mask
+
+
+@pytest.fixture
+def example_data_prox_operator():
+ n_neurons = 3
+ n_features = 4
+
+ params = (jnp.ones((n_neurons, n_features)), jnp.zeros(n_neurons))
+ regularizer_strength = 0.1
+ mask = jnp.array([[1, 0, 1, 0], [0, 1, 0, 1]], dtype=jnp.float32)
+ scaling = 0.5
+
+ return params, regularizer_strength, mask, scaling
+
+
+@pytest.fixture
+def poisson_observation_model():
+ return nmo.observation_models.PoissonObservations(jnp.exp)
+
+
+@pytest.fixture
+def ridge_regularizer():
+ return nmo.regularizer.Ridge(solver_name="LBFGS", regularizer_strength=0.1)
+
+
+@pytest.fixture
+def lasso_regularizer():
+ return nmo.regularizer.Lasso(
+ solver_name="ProximalGradient", regularizer_strength=0.1
+ )
+
+
+@pytest.fixture
+def group_lasso_2groups_5features_regularizer():
+ mask = np.zeros((2, 5))
+ mask[0, :2] = 1
+ mask[1, 2:] = 1
+ return nmo.regularizer.GroupLasso(
+ solver_name="ProximalGradient", mask=mask, regularizer_strength=0.1
+ )
+
+
+@pytest.fixture
+def mock_data():
+ return jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]), jnp.array([[1, 2], [3, 4]])
+
+
+@pytest.fixture()
+def glm_class():
+ return nmo.glm.GLM
diff --git a/tests/simulate_coupled_neurons_params.json b/tests/simulate_coupled_neurons_params.json
new file mode 100644
index 00000000..a8465f12
--- /dev/null
+++ b/tests/simulate_coupled_neurons_params.json
@@ -0,0 +1 @@
+{"coef_": [[-0.004372, -0.02786, -0.04582, -0.0588, -0.06539, -0.06396, -0.05328, -0.03192, 0.0002296, 0.04143, 0.08794, 0.1483, 0.2053, 0.2483, 0.2892, 0.3093, 0.2917, 0.2225, 0.07357, -0.2711, -0.006235, -0.01047, 0.02189, 0.058, 0.09002, 0.1118, 0.1209, 0.1167, 0.09909, 0.07044, 0.03448, -0.01565, -0.06823, -0.1128, -0.1655, -0.2176, -0.2621, -0.2982, -0.3255, -0.3449, 0.5, 0.5], [-0.004637, 0.02223, 0.07071, 0.09572, 0.1012, 0.08923, 0.06464, 0.03076, -0.007911, -0.04737, -0.08429, -0.1249, -0.1582, -0.1827, -0.2081, -0.23, -0.2473, -0.2616, -0.2741, -0.287, 0.01127, 0.04864, 0.0544, 0.05082, 0.03975, 0.02393, 0.004725, -0.01763, -0.04202, -0.06744, -0.09269, -0.1231, -0.1522, -0.1763, -0.2051, -0.2348, -0.2629, -0.2896, -0.3149, -0.3389, 0.5, 0.5]], "coupling_basis": [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0024979173609873673, 0.9975020826390129], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.11451325277931029, 0.8854867472206909, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25013898844998006, 0.7498610115500185, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3122501403134024, 0.687749859686596, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.28176761370807446, 0.7182323862919272, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.17383844924397923, 0.8261615507560222, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.04364762794083282, 0.9563523720591665, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9912618171282106, 0.008738182871789013, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7892946476427273, 0.21070535235727128, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3531647741677867, 0.6468352258322151, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.011883820048045501, 0.9881161799519544, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7841665801263835, 0.21583341987361648, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.17688067665784446, 0.8231193233421555, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9253003862638604, 0.0746996137361397, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2549435480705588, 0.7450564519294413, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9205258993369989, 0.07947410066300109, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.16827351931758228, 0.8317264806824178, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7835282009408713, 0.21647179905912872, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.019118847416525586, 0.9808811525834744, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.4372031242218587, 0.5627968757781414, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.9120243919870162, 0.08797560801298382, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.044222034278324274, 0.9557779657216758, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.40793669708774605, 0.5920633029122541, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.8283923698925478, 0.17160763010745222, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.9999802058373224, 1.9794162677666538e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.1458111022283093, 0.8541888977716907, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.4778824971400245, 0.5221175028599756, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.803486827077907, 0.19651317292209308, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.9824675828481839, 0.017532417151816082, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.029720664099906924, 0.9702793359000932, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.19724020774947038, 0.8027597922505296, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.44389603578613035, 0.5561039642138698, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.6909694421867117, 0.30903055781328825, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.8804498633788072, 0.1195501366211929, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.9828262050955638, 0.017173794904436157, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.005816278861877466, 0.9941837211381226, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.07171948190677246, 0.9282805180932275, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.19211081158089233, 0.8078891884191077, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.3422365913893123, 0.6577634086106878, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.49997219806462273, 0.5000278019353773, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.6481581380891199, 0.3518418619108801, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.775227808426499, 0.22477219157350103, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.8747644272334134, 0.12523557276658664, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.9445228823471115, 0.05547711765288865, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.9852942394771702, 0.014705760522829736, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.9998405276097415, 0.00015947239025848603, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.00798856965539202, 0.9920114303446079, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03392307742054024, 0.9660769225794598, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.07373523476821137, 0.9262647652317886, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.12352988337197751, 0.8764701166280225, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.17990211564285485, 0.8200978843571451, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2399997347398921, 0.7600002652601079, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3015222924967669, 0.6984777075032332, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.36268149196393995, 0.63731850803606, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.42214108290743424, 0.5778589170925659, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.47894873221112266, 0.5210512677888774, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5324679173051469, 0.46753208269485313, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5823146093533313, 0.4176853906466687, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6283012081735033, 0.3716987918264968, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6703886551778314, 0.32961134482216864, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7086466881407022, 0.2913533118592979, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7432216468423799, 0.25677835315762026, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7743109612271127, 0.22568903877288732, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.802143356101582, 0.197856643898418, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.82696381862707, 0.17303618137292998, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8490224486822571, 0.15097755131774288, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8685664156253453, 0.13143358437465474, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8858343578296817, 0.11416564217031833, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9010526715389762, 0.09894732846102389, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9144332365128198, 0.08556676348718023, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9261722145965264, 0.07382778540347357, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9364496329422705, 0.06355036705772948, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9454295266061546, 0.05457047339384541, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9532604668007324, 0.04673953319926766, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9600763426393057, 
0.039923657360694254, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9659972972699125, 0.03400270273008754, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.971130745291511, 0.028869254708488945, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.975572418558468, 0.024427581441531954, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9794074030288873, 0.020592596971112653, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9827111411428311, 0.017288858857168965, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9855503831123861, 0.014449616887613925, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9879840771076767, 0.012015922892323394, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9900641931482845, 0.009935806851715523, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9918364789707291, 0.008163521029270815, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9933411485659462, 0.006658851434053759, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9946135057219054, 0.005386494278094567, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9956845059646938, 0.004315494035306178, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9965812609202838, 0.0034187390797163486, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.997327489436671, 0.002672510563328956, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9979439199017871, 0.002056080098212898, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9984486481342357, 0.0015513518657642722, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9988574550621354, 0.0011425449378646424, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9991840881776304, 0.0008159118223696749, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.999440510488429, 0.0005594895115710874, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9996371204027914, 0.00036287959720865404, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.999782945694725, 0.00021705430527496627, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9998858144113889, 0.00011418558861114869, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9999525053112863, 4.7494688713622946e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.99998888016377, 1.1119836230089053e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], "feedforward_input": [[[0.0, 1.0], [0.0, 1.0]], [[0.012578617838741058, 0.9999208860571255], [0.012578617838741058, 
0.9999208860571255]], [[0.025155245389375847, 0.9996835567465339], [0.025155245389375847, 0.9996835567465339]], [[0.03772789267871718, 0.99928804962034], [0.03772789267871718, 0.99928804962034]], [[0.05029457036336618, 0.9987344272588006], [0.05029457036336618, 0.9987344272588006]], [[0.06285329004448194, 0.9980227772604111], [0.06285329004448194, 0.9980227772604111]], [[0.07540206458240159, 0.9971532122280464], [0.07540206458240159, 0.9971532122280464]], [[0.08793890841106125, 0.9961258697511429], [0.08793890841106125, 0.9961258697511429]], [[0.10046183785216795, 0.9949409123839288], [0.10046183785216795, 0.9949409123839288]], [[0.11296887142907283, 0.9935985276197029], [0.11296887142907283, 0.9935985276197029]], [[0.12545803018029603, 0.9920989278611685], [0.12545803018029603, 0.9920989278611685]], [[0.13792733797265358, 0.9904423503868246], [0.13792733797265358, 0.9904423503868246]], [[0.1503748218139367, 0.9886290573134227], [0.1503748218139367, 0.9886290573134227]], [[0.1627985121650943, 0.986659335554492], [0.1627985121650943, 0.986659335554492]], [[0.17519644325186898, 0.984533496774942], [0.17519644325186898, 0.984533496774942]], [[0.18756665337583714, 0.9822518773417481], [0.18756665337583714, 0.9822518773417481]], [[0.19990718522480458, 0.9798148382707295], [0.19990718522480458, 0.9798148382707295]], [[0.21221608618250787, 0.9772227651694256], [0.21221608618250787, 0.9772227651694256]], [[0.22449140863757258, 0.9744760681760832], [0.22449140863757258, 0.9744760681760832]], [[0.23673121029167973, 0.9715751818947602], [0.23673121029167973, 0.9715751818947602]], [[0.2489335544668916, 0.9685205653265598], [0.2489335544668916, 0.9685205653265598]], [[0.2610965104120882, 0.9653127017970033], [0.2610965104120882, 0.9653127017970033]], [[0.27321815360846585, 0.9619520988795548], [0.27321815360846585, 0.9619520988795548]], [[0.28529656607404974, 0.9584392883153087], [0.28529656607404974, 0.9584392883153087]], [[0.2973298366671723, 0.9547748259288535], [0.2973298366671723, 0.9547748259288535]], [[0.30931606138886886, 0.9509592915403253], [0.30931606138886886, 0.9509592915403253]], [[0.32125334368414366, 0.9469932888736633], [0.32125334368414366, 0.9469932888736633]], [[0.33313979474205757, 0.9428774454610842], [0.33313979474205757, 0.9428774454610842]], [[0.34497353379459045, 0.9386124125437894], [0.34497353379459045, 0.9386124125437894]], [[0.3567526884142317, 0.9341988649689198], [0.3567526884142317, 0.9341988649689198]], [[0.3684753948102499, 0.9296375010827771], [0.3684753948102499, 0.9296375010827771]], [[0.38013979812359666, 0.924929042620325], [0.38013979812359666, 0.924929042620325]], [[0.3917440527203973, 0.9200742345909914], [0.3917440527203973, 0.9200742345909914]], [[0.4032863224839812, 0.915073845160786], [0.4032863224839812, 0.915073845160786]], [[0.41476478110540693, 0.9099286655307568], [0.41476478110540693, 0.9099286655307568]], [[0.4261776123724353, 0.9046395098117981], [0.4261776123724353, 0.9046395098117981]], [[0.4375230104569043, 0.8992072148958368], [0.4375230104569043, 0.8992072148958368]], [[0.4487991802004621, 0.8936326403234123], [0.4487991802004621, 0.8936326403234123]], [[0.46000433739861096, 0.887916668147673], [0.46000433739861096, 0.887916668147673]], [[0.47113670908301786, 0.8820602027948115], [0.47113670908301786, 0.8820602027948115]], [[0.4821945338020477, 0.8760641709209582], [0.4821945338020477, 0.8760641709209582]], [[0.4931760618994744, 0.8699295212655597], [0.4931760618994744, 0.8699295212655597]], [[0.5040795557913246, 0.8636572245012607], 
[0.5040795557913246, 0.8636572245012607]], [[0.5149032902408126, 0.8572482730803168], [0.5149032902408126, 0.8572482730803168]], [[0.5256455526313207, 0.8507036810775614], [0.5256455526313207, 0.8507036810775614]], [[0.5363046432373825, 0.8440244840299503], [0.5363046432373825, 0.8440244840299503]], [[0.5468788754936273, 0.8372117387727107], [0.5468788754936273, 0.8372117387727107]], [[0.5573665762616421, 0.8302665232721208], [0.5573665762616421, 0.8302665232721208]], [[0.5677660860947078, 0.8231899364549453], [0.5677660860947078, 0.8231899364549453]], [[0.5780757595003707, 0.8159830980345546], [0.5780757595003707, 0.8159830980345546]], [[0.588293965200805, 0.8086471483337551], [0.588293965200805, 0.8086471483337551]], [[0.5984190863909268, 0.8011832481043575], [0.5984190863909268, 0.8011832481043575]], [[0.608449520994217, 0.7935925783435149], [0.608449520994217, 0.7935925783435149]], [[0.6183836819162153, 0.7858763401068549], [0.6183836819162153, 0.7858763401068549]], [[0.6282199972956423, 0.7780357543184395], [0.6282199972956423, 0.7780357543184395]], [[0.6379569107531118, 0.7700720615775812], [0.6379569107531118, 0.7700720615775812]], [[0.647592881637394, 0.7619865219625451], [0.647592881637394, 0.7619865219625451]], [[0.6571263852691885, 0.7537804148311695], [0.6571263852691885, 0.7537804148311695]], [[0.666555913182372, 0.7454550386184362], [0.666555913182372, 0.7454550386184362]], [[0.675879973362679, 0.7370117106310213], [0.675879973362679, 0.7370117106310213]], [[0.6850970904837809, 0.7284517668388609], [0.6850970904837809, 0.7284517668388609]], [[0.6942058061407225, 0.7197765616637636], [0.6942058061407225, 0.7197765616637636]], [[0.7032046790806838, 0.7109874677651024], [0.7032046790806838, 0.7109874677651024]], [[0.7120922854310254, 0.7020858758226226], [0.7120922854310254, 0.7020858758226226]], [[0.720867218924585, 0.6930731943163971], [0.720867218924585, 0.6930731943163971]], [[0.7295280911221884, 0.6839508493039657], [0.7295280911221884, 0.6839508493039657]], [[0.7380735316323389, 0.6747202841946927], [0.7380735316323389, 0.6747202841946927]], [[0.746502188328052, 0.6653829595213794], [0.746502188328052, 0.6653829595213794]], [[0.7548127275607989, 0.6559403527091677], [0.7548127275607989, 0.6559403527091677]], [[0.7630038343715272, 0.6463939578417693], [0.7630038343715272, 0.6463939578417693]], [[0.7710742126987247, 0.6367452854250606], [0.7710742126987247, 0.6367452854250606]], [[0.7790225855834911, 0.6269958621480786], [0.7790225855834911, 0.6269958621480786]], [[0.7868476953715899, 0.6171472306414553], [0.7868476953715899, 0.6171472306414553]], [[0.7945483039124437, 0.6072009492333317], [0.7945483039124437, 0.6072009492333317]], [[0.8021231927550437, 0.5971585917027863], [0.8021231927550437, 0.5971585917027863]], [[0.809571163340744, 0.5870217470308187], [0.809571163340744, 0.5870217470308187]], [[0.8168910371929053, 0.5767920191489297], [0.8168910371929053, 0.5767920191489297]], [[0.8240816561033644, 0.566471026685334], [0.8240816561033644, 0.566471026685334]], [[0.8311418823156935, 0.5560604027088476], [0.8311418823156935, 0.5560604027088476]], [[0.8380705987052264, 0.545561794470492], [0.8380705987052264, 0.545561794470492]], [[0.8448667089558177, 0.5349768631428518], [0.8448667089558177, 0.5349768631428518]], [[0.8515291377333112, 0.5243072835572319], [0.8515291377333112, 0.5243072835572319]], [[0.8580568308556875, 0.5135547439386516], [0.8580568308556875, 0.5135547439386516]], [[0.8644487554598649, 0.5027209456387218], [0.8644487554598649, 0.5027209456387218]], 
[[0.8707039001651274, 0.4918076028664418], [0.8707039001651274, 0.4918076028664418]], [[0.8768212752331536, 0.4808164424169648], [0.8768212752331536, 0.4808164424169648]], [[0.8827999127246196, 0.4697492033983709], [0.8827999127246196, 0.4697492033983709]], [[0.8886388666523558, 0.45860763695649104], [0.8886388666523558, 0.45860763695649104]], [[0.8943372131310272, 0.4473935059978269], [0.8943372131310272, 0.4473935059978269]], [[0.8998940505233182, 0.4361085849106111], [0.8998940505233182, 0.4361085849106111]], [[0.9053084995825966, 0.42475465928404793], [0.9053084995825966, 0.42475465928404793]], [[0.9105797035920355, 0.4133335256257842], [0.9105797035920355, 0.4133335256257842]], [[0.9157068285001692, 0.4018469910776512], [0.9157068285001692, 0.4018469910776512]], [[0.920689063052863, 0.3902968731297256], [0.920689063052863, 0.3902968731297256]], [[0.9255256189216778, 0.3786849993327503], [0.9255256189216778, 0.3786849993327503]], [[0.9302157308286042, 0.3670132070089654], [0.9302157308286042, 0.3670132070089654]], [[0.934758656667151, 0.35528334296139374], [0.934758656667151, 0.35528334296139374]], [[0.9391536776197676, 0.34349726318162344], [0.9391536776197676, 0.34349726318162344]], [[0.9434000982715812, 0.3316568325561391], [0.9434000982715812, 0.3316568325561391]], [[0.9474972467204298, 0.31976392457124536], [0.9474972467204298, 0.31976392457124536]], [[0.9514444746831766, 0.30782042101662793], [0.9514444746831766, 0.30782042101662793]], [[0.9552411575982869, 0.2958282116876025], [0.9552411575982869, 0.2958282116876025]], [[0.9588866947246497, 0.28378919408609693], [0.9588866947246497, 0.28378919408609693]], [[0.9623805092366334, 0.27170527312041276], [0.9623805092366334, 0.27170527312041276]], [[0.9657220483153546, 0.25957836080381586], [0.9657220483153546, 0.25957836080381586]], [[0.9689107832361495, 0.24741037595200252], [0.9689107832361495, 0.24741037595200252]], [[0.9719462094522335, 0.23520324387949015], [0.9719462094522335, 0.23520324387949015]], [[0.9748278466745341, 0.2229588960949774], [0.9748278466745341, 0.2229588960949774]], [[0.9775552389476861, 0.21067926999572642], [0.9775552389476861, 0.21067926999572642]], [[0.9801279547221765, 0.19836630856101303], [0.9801279547221765, 0.19836630856101303]], [[0.9825455869226277, 0.18602196004469224], [0.9825455869226277, 0.18602196004469224]], [[0.984807753012208, 0.17364817766693041], [0.984807753012208, 0.17364817766693041]], [[0.98691409505316, 0.16124691930515242], [0.98691409505316, 0.16124691930515242]], [[0.9888642797634357, 0.14882014718424924], [0.9888642797634357, 0.14882014718424924]], [[0.9906579985694317, 0.1363698275661], [0.9906579985694317, 0.1363698275661]], [[0.9922949676548136, 0.12389793043845522], [0.9922949676548136, 0.12389793043845522]], [[0.9937749280054242, 0.11140642920322849], [0.9937749280054242, 0.11140642920322849]], [[0.995097645450266, 0.09889730036424986], [0.995097645450266, 0.09889730036424986]], [[0.9962629106985543, 0.08637252321452853], [0.9962629106985543, 0.08637252321452853]], [[0.9972705393728327, 0.07383407952307214], [0.9972705393728327, 0.07383407952307214]], [[0.9981203720381463, 0.06128395322131638], [0.9981203720381463, 0.06128395322131638]], [[0.9988122742272691, 0.04872413008921228], [0.9988122742272691, 0.04872413008921228]], [[0.9993461364619809, 0.036156597441019206], [0.9993461364619809, 0.036156597441019206]], [[0.9997218742703887, 0.023583343810857166], [0.9997218742703887, 0.023583343810857166]], [[0.9999394282002937, 0.011006358638064812], [0.9999394282002937, 
0.011006358638064812]], [[0.9999987638285974, -0.001572368047584414], [0.9999987638285974, -0.001572368047584414]], [[0.9998998717667489, -0.014150845940761853], [0.9998998717667489, -0.014150845940761853]], [[0.9996427676622299, -0.026727084775504745], [0.9996427676622299, -0.026727084775504745]], [[0.9992274921960794, -0.03929909464013115], [0.9992274921960794, -0.03929909464013115]], [[0.9986541110764565, -0.0518648862921008], [0.9986541110764565, -0.0518648862921008]], [[0.9979227150282433, -0.06442247147276806], [0.9979227150282433, -0.06442247147276806]], [[0.9970334197786902, -0.07696986322197923], [0.9970334197786902, -0.07696986322197923]], [[0.9959863660391044, -0.08950507619246638], [0.9959863660391044, -0.08950507619246638]], [[0.9947817194825853, -0.10202612696398403], [0.9947817194825853, -0.10202612696398403]], [[0.9934196707178107, -0.11453103435714077], [0.9934196707178107, -0.11453103435714077]], [[0.991900435258877, -0.12701781974687854], [0.991900435258877, -0.12701781974687854]], [[0.9902242534911986, -0.1394845073755453], [0.9902242534911986, -0.1394845073755453]], [[0.9883913906334728, -0.15192912466551547], [0.9883913906334728, -0.15192912466551547]], [[0.9864021366957146, -0.16434970253130593], [0.9864021366957146, -0.16434970253130593]], [[0.9842568064333687, -0.17674427569114137], [0.9842568064333687, -0.17674427569114137]], [[0.9819557392975067, -0.18911088297791617], [0.9819557392975067, -0.18911088297791617]], [[0.9794992993811165, -0.20144756764950503], [0.9794992993811165, -0.20144756764950503]], [[0.9768878753614926, -0.21375237769837538], [0.9768878753614926, -0.21375237769837538]], [[0.9741218804387363, -0.22602336616044894], [0.9741218804387363, -0.22602336616044894]], [[0.9712017522703763, -0.23825859142316483], [0.9712017522703763, -0.23825859142316483]], [[0.9681279529021188, -0.25045611753269825], [0.9681279529021188, -0.25045611753269825]], [[0.9649009686947391, -0.2626140145002818], [0.9649009686947391, -0.2626140145002818]], [[0.9615213102471255, -0.27473035860758266], [0.9615213102471255, -0.27473035860758266]], [[0.9579895123154889, -0.28680323271109], [0.9579895123154889, -0.28680323271109]], [[0.9543061337287488, -0.29883072654545967], [0.9543061337287488, -0.29883072654545967]], [[0.9504717573001116, -0.310810937025771], [0.9504717573001116, -0.310810937025771]], [[0.9464869897348526, -0.32274196854864906], [0.9464869897348526, -0.32274196854864906]], [[0.9423524615343186, -0.33462193329220136], [0.9423524615343186, -0.33462193329220136]], [[0.9380688268961659, -0.3464489515147234], [0.9380688268961659, -0.3464489515147234]], [[0.9336367636108462, -0.3582211518521272], [0.9336367636108462, -0.3582211518521272]], [[0.9290569729543628, -0.369936671614043], [0.9290569729543628, -0.369936671614043]], [[0.9243301795773085, -0.38159365707854837], [0.9243301795773085, -0.38159365707854837]], [[0.9194571313902055, -0.3931902637854787], [0.9194571313902055, -0.3931902637854787]], [[0.9144385994451658, -0.40472465682827324], [0.9144385994451658, -0.40472465682827324]], [[0.9092753778138886, -0.4161950111443075], [0.9092753778138886, -0.4161950111443075]], [[0.9039682834620162, -0.42759951180366895], [0.9039682834620162, -0.42759951180366895]], [[0.8985181561198674, -0.4389363542963303], [0.8985181561198674, -0.4389363542963303]], [[0.8929258581495686, -0.450203744817673], [0.8929258581495686, -0.450203744817673]], [[0.8871922744086043, -0.46139990055231683], [0.8871922744086043, -0.46139990055231683]], [[0.881318312109807, -0.47252304995621186], 
[0.881318312109807, -0.47252304995621186]], [[0.8753049006778131, -0.4835714330369443], [0.8753049006778131, -0.4835714330369443]], [[0.869152991601999, -0.4945433016322186], [0.869152991601999, -0.4945433016322186]], [[0.8628635582859312, -0.5054369196864643], [0.8628635582859312, -0.5054369196864643]], [[0.856437595893346, -0.5162505635255284], [0.856437595893346, -0.5162505635255284]], [[0.8498761211906867, -0.5269825221294092], [0.8498761211906867, -0.5269825221294092]], [[0.8431801723862224, -0.5376310974029872], [0.8431801723862224, -0.5376310974029872]], [[0.8363508089657762, -0.5481946044447097], [0.8363508089657762, -0.5481946044447097]], [[0.8293891115250829, -0.5586713718131919], [0.8293891115250829, -0.5586713718131919]], [[0.8222961815988096, -0.5690597417916836], [0.8222961815988096, -0.5690597417916836]], [[0.8150731414862624, -0.5793580706503667], [0.8150731414862624, -0.5793580706503667]], [[0.8077211340738071, -0.5895647289064391], [0.8077211340738071, -0.5895647289064391]], [[0.800241322654032, -0.5996781015819448], [0.800241322654032, -0.5996781015819448]], [[0.7926348907416848, -0.6096965884593069], [0.7926348907416848, -0.6096965884593069]], [[0.7849030418864046, -0.6196186043345285], [0.7849030418864046, -0.6196186043345285]], [[0.7770469994822886, -0.6294425792680156], [0.7770469994822886, -0.6294425792680156]], [[0.769068006574317, -0.6391669588329847], [0.769068006574317, -0.6391669588329847]], [[0.7609673256616678, -0.648790204361417], [0.7609673256616678, -0.648790204361417]], [[0.7527462384979551, -0.6583107931875185], [0.7527462384979551, -0.6583107931875185]], [[0.744406045888419, -0.6677272188886485], [0.744406045888419, -0.6677272188886485]], [[0.7359480674841035, -0.6770379915236763], [0.7359480674841035, -0.6770379915236763]], [[0.7273736415730488, -0.6862416378687335], [0.7273736415730488, -0.6862416378687335]], [[0.7186841248685385, -0.6953367016503177], [0.7186841248685385, -0.6953367016503177]], [[0.7098808922944289, -0.7043217437757161], [0.7098808922944289, -0.7043217437757161]], [[0.7009653367675978, -0.7131953425607098], [0.7009653367675978, -0.7131953425607098]], [[0.6919388689775463, -0.7219560939545244], [0.6919388689775463, -0.7219560939545244]], [[0.6828029171631891, -0.7306026117619886], [0.6828029171631891, -0.7306026117619886]], [[0.673558926886866, -0.739133527862871], [0.673558926886866, -0.739133527862871]], [[0.6642083608056142, -0.7475474924283534], [0.6642083608056142, -0.7475474924283534]], [[0.6547526984397353, -0.7558431741346118], [0.6547526984397353, -0.7558431741346118]], [[0.6451934359386937, -0.764019260373469], [0.6451934359386937, -0.764019260373469]], [[0.6355320858443845, -0.7720744574600859], [0.6355320858443845, -0.7720744574600859]], [[0.6257701768518059, -0.7800074908376582], [0.6257701768518059, -0.7800074908376582]], [[0.6159092535671797, -0.7878171052790867], [0.6159092535671797, -0.7878171052790867]], [[0.6059508762635484, -0.7955020650855897], [0.6059508762635484, -0.7955020650855897]], [[0.5958966206338979, -0.8030611542822255], [0.5958966206338979, -0.8030611542822255]], [[0.5857480775418397, -0.8104931768102919], [0.5857480775418397, -0.8104931768102919]], [[0.5755068527698903, -0.8177969567165775], [0.5755068527698903, -0.8177969567165775]], [[0.5651745667653929, -0.8249713383394301], [0.5651745667653929, -0.8249713383394301]], [[0.5547528543841173, -0.8320151864916135], [0.5547528543841173, -0.8320151864916135]], [[0.5442433646315792, -0.8389273866399272], [0.5442433646315792, -0.8389273866399272]], 
[[0.5336477604021226, -0.8457068450815559], [0.5336477604021226, -0.8457068450815559]], [[0.5229677182158028, -0.8523524891171238], [0.5229677182158028, -0.8523524891171238]], [[0.5122049279531147, -0.8588632672204258], [0.5122049279531147, -0.8588632672204258]], [[0.5013610925876063, -0.865238149204808], [0.5013610925876063, -0.865238149204808]], [[0.49043792791642066, -0.8714761263861723], [0.49043792791642066, -0.8714761263861723]], [[0.47943716228880995, -0.8775762117425775], [0.47943716228880995, -0.8775762117425775]], [[0.4683605363326608, -0.8835374400704151], [0.4683605363326608, -0.8835374400704151]], [[0.4572098026790794, -0.8893588681371302], [0.4572098026790794, -0.8893588681371302]], [[0.44598672568507636, -0.8950395748304677], [0.44598672568507636, -0.8950395748304677]], [[0.4346930811543961, -0.9005786613042182], [0.4346930811543961, -0.9005786613042182]], [[0.4233306560565345, -0.9059752511204399], [0.4233306560565345, -0.9059752511204399]], [[0.4119012482439928, -0.9112284903881356], [0.4119012482439928, -0.9112284903881356]], [[0.40040666616780407, -0.916337547898363], [0.40040666616780407, -0.916337547898363]], [[0.3888487285913878, -0.9213016152557539], [0.3888487285913878, -0.9213016152557539]], [[0.37722926430277026, -0.9261199070064258], [0.37722926430277026, -0.9261199070064258]], [[0.36555011182521946, -0.9307916607622618], [0.36555011182521946, -0.9307916607622618]], [[0.3538131191263388, -0.9353161373215428], [0.3538131191263388, -0.9353161373215428]], [[0.3420201433256689, -0.9396926207859083], [0.3420201433256689, -0.9396926207859083]], [[0.330173050400837, -0.9439204186736329], [0.330173050400837, -0.9439204186736329]], [[0.3182737148923088, -0.9479988620291954], [0.3182737148923088, -0.9479988620291954]], [[0.3063240196067838, -0.9519273055291264], [0.3063240196067838, -0.9519273055291264]], [[0.29432585531928224, -0.9557051275841167], [0.29432585531928224, -0.9557051275841167]], [[0.2822811204739722, -0.9593317304373701], [0.2822811204739722, -0.9593317304373701]], [[0.27019172088378224, -0.9628065402591843], [0.27019172088378224, -0.9628065402591843]], [[0.25805956942885044, -0.9661290072377479], [0.25805956942885044, -0.9661290072377479]], [[0.24588658575385056, -0.9692986056661355], [0.24588658575385056, -0.9692986056661355]], [[0.23367469596425278, -0.9723148340254889], [0.23367469596425278, -0.9723148340254889]], [[0.22142583232155955, -0.975177215064372], [0.22142583232155955, -0.975177215064372]], [[0.20914193293756786, -0.977885295874285], [0.20914193293756786, -0.977885295874285]], [[0.19682494146770554, -0.9804386479613267], [0.19682494146770554, -0.9804386479613267]], [[0.18447680680349254, -0.9828368673139948], [0.18447680680349254, -0.9828368673139948]], [[0.17209948276416928, -0.9850795744671115], [0.17209948276416928, -0.9850795744671115]], [[0.15969492778754976, -0.9871664145618657], [0.15969492778754976, -0.9871664145618657]], [[0.14726510462014156, -0.9890970574019613], [0.14726510462014156, -0.9890970574019613]], [[0.1348119800065847, -0.9908711975058636], [0.1348119800065847, -0.9908711975058636]], [[0.12233752437845731, -0.992488554155135], [0.12233752437845731, -0.992488554155135]], [[0.1098437115425002, -0.9939488714388522], [0.1098437115425002, -0.9939488714388522]], [[0.09733251836830287, -0.9952519182940991], [0.09733251836830287, -0.9952519182940991]], [[0.0848059244755095, -0.9963974885425265], [0.0848059244755095, -0.9963974885425265]], [[0.07226591192058739, -0.9973854009229761], [0.07226591192058739, -0.9973854009229761]], 
[[0.05971446488321034, -0.9982154991201608], [0.05971446488321034, -0.9982154991201608]], [[0.04715356935230619, -0.9988876517893978], [0.04715356935230619, -0.9988876517893978]], [[0.034585212811817465, -0.9994017525773913], [0.034585212811817465, -0.9994017525773913]], [[0.022011383926227784, -0.9997577201390606], [0.022011383926227784, -0.9997577201390606]], [[0.009434072225897046, -0.999955498150411], [0.009434072225897046, -0.999955498150411]], [[-0.0031447322077359985, -0.9999950553174459], [-0.0031447322077359985, -0.9999950553174459]], [[-0.015723039057040564, -0.9998763853811183], [-0.015723039057040564, -0.9998763853811183]], [[-0.02829885808311759, -0.9995995071183217], [-0.02829885808311759, -0.9995995071183217]], [[-0.04087019944071145, -0.9991644643389178], [-0.04087019944071145, -0.9991644643389178]], [[-0.053435073993057226, -0.9985713258788059], [-0.053435073993057226, -0.9985713258788059]], [[-0.06599149362662023, -0.9978201855890307], [-0.06599149362662023, -0.9978201855890307]], [[-0.07853747156566927, -0.996911162320932], [-0.07853747156566927, -0.996911162320932]], [[-0.09107102268664041, -0.9958443999073396], [-0.09107102268664041, -0.9958443999073396]], [[-0.10359016383223883, -0.9946200671398149], [-0.10359016383223883, -0.9946200671398149]], [[-0.11609291412522968, -0.993238357741943], [-0.11609291412522968, -0.993238357741943]], [[-0.12857729528186848, -0.9916994903386808], [-0.12857729528186848, -0.9916994903386808]], [[-0.14104133192491908, -0.9900037084217639], [-0.14104133192491908, -0.9900037084217639]], [[-0.15348305189621594, -0.9881512803111796], [-0.15348305189621594, -0.9881512803111796]], [[-0.16590048656871298, -0.9861424991127116], [-0.16590048656871298, -0.9861424991127116]], [[-0.1782916711579755, -0.9839776826715616], [-0.1782916711579755, -0.9839776826715616]], [[-0.19065464503306404, -0.9816571735220583], [-0.19065464503306404, -0.9816571735220583]], [[-0.20298745202676116, -0.979181338833458], [-0.20298745202676116, -0.979181338833458]], [[-0.2152881407450901, -0.9765505703518493], [-0.2152881407450901, -0.9765505703518493]], [[-0.2275547648760821, -0.9737652843381669], [-0.2275547648760821, -0.9737652843381669]], [[-0.23978538349773562, -0.9708259215023277], [-0.23978538349773562, -0.9708259215023277]], [[-0.25197806138512474, -0.967732946933499], [-0.25197806138512474, -0.967732946933499]], [[-0.2641308693166058, -0.9644868500265071], [-0.2641308693166058, -0.9644868500265071]], [[-0.2762418843790738, -0.9610881444044029], [-0.2762418843790738, -0.9610881444044029]], [[-0.2883091902722216, -0.9575373678371909], [-0.2883091902722216, -0.9575373678371909]], [[-0.3003308776117502, -0.9538350821567405], [-0.3003308776117502, -0.9538350821567405]], [[-0.31230504423148914, -0.9499818731678872], [-0.31230504423148914, -0.9499818731678872]], [[-0.32422979548437053, -0.9459783505557425], [-0.32422979548437053, -0.9459783505557425]], [[-0.33610324454221563, -0.9418251477892251], [-0.33610324454221563, -0.9418251477892251]], [[-0.34792351269428334, -0.9375229220208277], [-0.34792351269428334, -0.9375229220208277]], [[-0.3596887296445355, -0.9330723539826374], [-0.3596887296445355, -0.9330723539826374]], [[-0.3713970338075679, -0.9284741478786258], [-0.3713970338075679, -0.9284741478786258]], [[-0.3830465726031674, -0.9237290312732227], [-0.3830465726031674, -0.9237290312732227]], [[-0.3946355027494405, -0.9188377549761962], [-0.3946355027494405, -0.9188377549761962]], [[-0.406161990554472, -0.9138010929238535], [-0.406161990554472, 
-0.9138010929238535]], [[-0.41762421220646645, -0.9086198420565822], [-0.41762421220646645, -0.9086198420565822]], [[-0.4290203540623263, -0.9032948221927524], [-0.4290203540623263, -0.9032948221927524]], [[-0.44034861293461913, -0.8978268758989992], [-0.44034861293461913, -0.8978268758989992]], [[-0.4516071963768948, -0.892216868356904], [-0.4516071963768948, -0.892216868356904]], [[-0.46279432296729867, -0.8864656872260989], [-0.46279432296729867, -0.8864656872260989]], [[-0.47390822259044274, -0.8805742425038149], [-0.47390822259044274, -0.8805742425038149]], [[-0.4849471367174873, -0.8745434663808944], [-0.4849471367174873, -0.8745434663808944]], [[-0.495909318684389, -0.8683743130942929], [-0.495909318684389, -0.8683743130942929]], [[-0.5067930339682724, -0.8620677587760915], [-0.5067930339682724, -0.8620677587760915]], [[-0.5175965604618782, -0.8556248012990468], [-0.5175965604618782, -0.8556248012990468]], [[-0.5283181887460511, -0.849046460118698], [-0.5283181887460511, -0.849046460118698]], [[-0.538956222360216, -0.842333776112062], [-0.538956222360216, -0.842333776112062]], [[-0.5495089780708056, -0.8354878114129367], [-0.5495089780708056, -0.8354878114129367]], [[-0.5599747861375949, -0.8285096492438424], [-0.5599747861375949, -0.8285096492438424]], [[-0.5703519905779012, -0.8214003937446254], [-0.5703519905779012, -0.8214003937446254]], [[-0.5806389494286053, -0.814161169797753], [-0.5806389494286053, -0.814161169797753]], [[-0.5908340350059578, -0.8067931228503245], [-0.5908340350059578, -0.8067931228503245]], [[-0.6009356341631226, -0.7992974187328304], [-0.6009356341631226, -0.7992974187328304]], [[-0.6109421485454225, -0.7916752434746857], [-0.6109421485454225, -0.7916752434746857]], [[-0.6208519948432432, -0.7839278031165661], [-0.6208519948432432, -0.7839278031165661]], [[-0.630663605042557, -0.7760563235195791], [-0.630663605042557, -0.7760563235195791]], [[-0.6403754266730258, -0.7680620501712998], [-0.6403754266730258, -0.7680620501712998]], [[-0.6499859230536464, -0.7599462479886977], [-0.6499859230536464, -0.7599462479886977]], [[-0.6594935735358957, -0.7517102011179935], [-0.6594935735358957, -0.7517102011179935]], [[-0.6688968737443391, -0.7433552127314704], [-0.6688968737443391, -0.7433552127314704]], [[-0.6781943358146659, -0.7348826048212762], [-0.6781943358146659, -0.7348826048212762]], [[-0.6873844886291098, -0.7262937179902474], [-0.6873844886291098, -0.7262937179902474]], [[-0.6964658780492216, -0.717589911239788], [-0.6964658780492216, -0.717589911239788]], [[-0.7054370671459529, -0.7087725617548385], [-0.7054370671459529, -0.7087725617548385]], [[-0.7142966364270207, -0.6998430646859656], [-0.7142966364270207, -0.6998430646859656]], [[-0.723043184061509, -0.6908028329286112], [-0.723043184061509, -0.6908028329286112]], [[-0.731675326101678, -0.6816532968995332], [-0.731675326101678, -0.6816532968995332]], [[-0.7401916967019432, -0.6723959043104729], [-0.7401916967019432, -0.6723959043104729]], [[-0.7485909483349904, -0.6630321199390868], [-0.7485909483349904, -0.6630321199390868]], [[-0.7568717520049916, -0.6535634253971795], [-0.7568717520049916, -0.6535634253971795]], [[-0.7650327974578898, -0.6439913188962686], [-0.7650327974578898, -0.6439913188962686]], [[-0.7730727933887175, -0.634317315010528], [-0.7730727933887175, -0.634317315010528]], [[-0.7809904676459172, -0.6245429444371393], [-0.7809904676459172, -0.6245429444371393]], [[-0.788784567432631, -0.6146697537540928], [-0.788784567432631, -0.6146697537540928]], [[-0.7964538595049286, 
-0.6046993051754759], [-0.7964538595049286, -0.6046993051754759]], [[-0.8039971303669401, -0.5946331763042871], [-0.8039971303669401, -0.5946331763042871]], [[-0.8114131864628653, -0.5844729598828156], [-0.8114131864628653, -0.5844729598828156]], [[-0.8187008543658276, -0.5742202635406243], [-0.8187008543658276, -0.5742202635406243]], [[-0.825858980963543, -0.5638767095401779], [-0.825858980963543, -0.5638767095401779]], [[-0.8328864336407734, -0.5534439345201586], [-0.8328864336407734, -0.5534439345201586]], [[-0.8397821004585396, -0.5429235892364995], [-0.8397821004585396, -0.5429235892364995]], [[-0.8465448903300604, -0.5323173383011922], [-0.8465448903300604, -0.5323173383011922]], [[-0.8531737331933926, -0.521626859918898], [-0.8531737331933926, -0.521626859918898]], [[-0.8596675801807451, -0.5108538456214089], [-0.8596675801807451, -0.5108538456214089]], [[-0.8660254037844384, -0.5000000000000004], [-0.8660254037844384, -0.5000000000000004]], [[-0.872246198019486, -0.4890670404357173], [-0.872246198019486, -0.4890670404357173]], [[-0.8783289785827684, -0.4780566968276366], [-0.8783289785827684, -0.4780566968276366]], [[-0.8842727830087774, -0.46697071131914863], [-0.8842727830087774, -0.46697071131914863]], [[-0.8900766708219056, -0.4558108380223019], [-0.8900766708219056, -0.4558108380223019]], [[-0.895739723685255, -0.4445788427402534], [-0.895739723685255, -0.4445788427402534]], [[-0.9012610455459443, -0.4332765026878693], [-0.9012610455459443, -0.4332765026878693]], [[-0.9066397627768893, -0.4219056062105194], [-0.9066397627768893, -0.4219056062105194]], [[-0.9118750243150336, -0.410467952501114], [-0.9118750243150336, -0.410467952501114]], [[-0.9169660017960133, -0.39896535131541655], [-0.9169660017960133, -0.39896535131541655]], [[-0.921911889685225, -0.38739962268569333], [-0.921911889685225, -0.38739962268569333]], [[-0.9267119054052849, -0.37577259663273255], [-0.9267119054052849, -0.37577259663273255]], [[-0.931365289459854, -0.3640861128762842], [-0.931365289459854, -0.3640861128762842]], [[-0.9358713055538119, -0.3523420205439648], [-0.9358713055538119, -0.3523420205439648]], [[-0.9402292407097588, -0.3405421778786742], [-0.9402292407097588, -0.3405421778786742]], [[-0.9444384053808287, -0.32868845194456947], [-0.9444384053808287, -0.32868845194456947]], [[-0.948498133559795, -0.3167827183316434], [-0.948498133559795, -0.3167827183316434]], [[-0.9524077828844512, -0.30482686085895394], [-0.9524077828844512, -0.30482686085895394]], [[-0.9561667347392507, -0.2928227712765512], [-0.9561667347392507, -0.2928227712765512]], [[-0.959774394353189, -0.28077234896614933], [-0.959774394353189, -0.28077234896614933]], [[-0.9632301908939126, -0.26867750064059465], [-0.9632301908939126, -0.26867750064059465]], [[-0.9665335775580413, -0.25654014004216524], [-0.9665335775580413, -0.25654014004216524]], [[-0.9696840316576876, -0.2443621876397672], [-0.9696840316576876, -0.2443621876397672]], [[-0.97268105470316, -0.2321455703250619], [-0.97268105470316, -0.2321455703250619]], [[-0.9755241724818386, -0.21989222110757806], [-0.9755241724818386, -0.21989222110757806]], [[-0.9782129351332083, -0.2076040788088557], [-0.9782129351332083, -0.2076040788088557]], [[-0.9807469172200395, -0.19528308775567055], [-0.9807469172200395, -0.19528308775567055]], [[-0.9831257177957041, -0.18293119747238726], [-0.9831257177957041, -0.18293119747238726]], [[-0.9853489604676163, -0.17055036237249038], [-0.9853489604676163, -0.17055036237249038]], [[-0.9874162934567888, -0.15814254144934156], 
[-0.9874162934567888, -0.15814254144934156]], [[-0.9893273896534934, -0.14570969796621222], [-0.9893273896534934, -0.14570969796621222]], [[-0.9910819466690195, -0.1332537991456406], [-0.9910819466690195, -0.1332537991456406]], [[-0.9926796868835203, -0.1207768158581612], [-0.9926796868835203, -0.1207768158581612]], [[-0.9941203574899392, -0.10828072231046196], [-0.9941203574899392, -0.10828072231046196]], [[-0.9954037305340125, -0.09576749573300417], [-0.9954037305340125, -0.09576749573300417]], [[-0.9965296029503367, -0.08323911606717305], [-0.9965296029503367, -0.08323911606717305]], [[-0.9974977965944997, -0.070697565651995], [-0.9974977965944997, -0.070697565651995]], [[-0.9983081582712682, -0.05814482891047624], [-0.9983081582712682, -0.05814482891047624]], [[-0.9989605597588274, -0.04558289203561173], [-0.9989605597588274, -0.04558289203561173]], [[-0.9994548978290693, -0.0330137426761141], [-0.9994548978290693, -0.0330137426761141]], [[-0.9997910942639261, -0.020439369621912166], [-0.9997910942639261, -0.020439369621912166]], [[-0.9999690958677468, -0.007861762489468911], [-0.9999690958677468, -0.007861762489468911]], [[-0.999988874475714, 0.004717088593031313], [-0.999988874475714, 0.004717088593031313]], [[-0.9998504269583004, 0.01729519330057657], [-0.9998504269583004, 0.01729519330057657]], [[-0.9995537752217639, 0.029870561426252256], [-0.9995537752217639, 0.029870561426252256]], [[-0.9990989662046815, 0.04244120319614822], [-0.9990989662046815, 0.04244120319614822]], [[-0.9984860718705224, 0.055005129584192916], [-0.9984860718705224, 0.055005129584192916]], [[-0.9977151891962615, 0.06756035262687816], [-0.9977151891962615, 0.06756035262687816]], [[-0.9967864401570343, 0.08010488573780679], [-0.9967864401570343, 0.08010488573780679]], [[-0.9956999717068378, 0.09263674402202696], [-0.9956999717068378, 0.09263674402202696]], [[-0.9944559557552776, 0.10515394459009784], [-0.9944559557552776, 0.10515394459009784]], [[-0.9930545891403677, 0.11765450687183807], [-0.9930545891403677, 0.11765450687183807]], [[-0.9914960935973849, 0.1301364529297071], [-0.9914960935973849, 0.1301364529297071]], [[-0.9897807157237836, 0.1425978077717702], [-0.9897807157237836, 0.1425978077717702]], [[-0.9879087269401782, 0.1550365996641971], [-0.9879087269401782, 0.1550365996641971]], [[-0.9858804234473959, 0.16745086044324545], [-0.9858804234473959, 0.16745086044324545]], [[-0.9836961261796103, 0.17983862582667898], [-0.9836961261796103, 0.17983862582667898]], [[-0.9813561807535597, 0.19219793572457194], [-0.9813561807535597, 0.19219793572457194]], [[-0.9788609574138615, 0.20452683454945075], [-0.9788609574138615, 0.20452683454945075]], [[-0.9762108509744296, 0.21682337152571898], [-0.9762108509744296, 0.21682337152571898]], [[-0.9734062807560028, 0.22908560099832972], [-0.9734062807560028, 0.22908560099832972]], [[-0.9704476905197971, 0.24131158274063894], [-0.9704476905197971, 0.24131158274063894]], [[-0.9673355483972903, 0.25349938226140434], [-0.9673355483972903, 0.25349938226140434]], [[-0.9640703468161508, 0.2656470711108758], [-0.9640703468161508, 0.2656470711108758]], [[-0.9606526024223212, 0.27775272718593], [-0.9606526024223212, 0.27775272718593]], [[-0.957082855998271, 0.28981443503420057], [-0.957082855998271, 0.28981443503420057]], [[-0.9533616723774295, 0.30183028615715607], [-0.9533616723774295, 0.30183028615715607]], [[-0.9494896403548136, 0.31379837931207794], [-0.9494896403548136, 0.31379837931207794]], [[-0.9454673725938637, 0.3257168208128897], [-0.9454673725938637, 
0.3257168208128897]], [[-0.9412955055295036, 0.33758372482979143], [-0.9412955055295036, 0.33758372482979143]], [[-0.9369746992674384, 0.34939721368765], [-0.9369746992674384, 0.34939721368765]], [[-0.9325056374797075, 0.361155418163101], [-0.9325056374797075, 0.361155418163101]], [[-0.9278890272965095, 0.3728564777803084], [-0.9278890272965095, 0.3728564777803084]], [[-0.9231255991943125, 0.3844985411053488], [-0.9231255991943125, 0.3844985411053488]], [[-0.9182161068802741, 0.3960797660391565], [-0.9182161068802741, 0.3960797660391565]], [[-0.9131613271729835, 0.4075983201089958], [-0.9131613271729835, 0.4075983201089958]], [[-0.9079620598795464, 0.41905238075840945], [-0.9079620598795464, 0.41905238075840945]], [[-0.9026191276690343, 0.4304401356355976], [-0.9026191276690343, 0.4304401356355976]], [[-0.8971333759423143, 0.4417597828801825], [-0.8971333759423143, 0.4417597828801825]], [[-0.8915056726982842, 0.4530095314083134], [-0.8915056726982842, 0.4530095314083134]], [[-0.8857369083965297, 0.4641876011960654], [-0.8857369083965297, 0.4641876011960654]], [[-0.8798279958164298, 0.4752922235610892], [-0.8798279958164298, 0.4752922235610892]], [[-0.873779869912729, 0.486321641442466], [-0.873779869912729, 0.486321641442466]], [[-0.8675934876676018, 0.49727410967872326], [-0.8675934876676018, 0.49727410967872326]], [[-0.8612698279392309, 0.5081478952839691], [-0.8612698279392309, 0.5081478952839691]], [[-0.8548098913069261, 0.5189412777220956], [-0.8548098913069261, 0.5189412777220956]], [[-0.8482146999128025, 0.5296525491790203], [-0.8482146999128025, 0.5296525491790203]], [[-0.8414852973000504, 0.5402800148329067], [-0.8414852973000504, 0.5402800148329067]], [[-0.8346227482478176, 0.5508219931223336], [-0.8346227482478176, 0.5508219931223336]], [[-0.8276281386027314, 0.5612768160123647], [-0.8276281386027314, 0.5612768160123647]], [[-0.8205025751070878, 0.5716428292584782], [-0.8205025751070878, 0.5716428292584782]], [[-0.8132471852237334, 0.5819183926683146], [-0.8132471852237334, 0.5819183926683146]], [[-0.8058631169576695, 0.5921018803612005], [-0.8058631169576695, 0.5921018803612005]], [[-0.7983515386744064, 0.6021916810254089], [-0.7983515386744064, 0.6021916810254089]], [[-0.7907136389150943, 0.6121861981731129], [-0.7907136389150943, 0.6121861981731129]], [[-0.7829506262084637, 0.6220838503929953], [-0.7829506262084637, 0.6220838503929953]], [[-0.7750637288796017, 0.6318830716004721], [-0.7750637288796017, 0.6318830716004721]], [[-0.7670541948555989, 0.6415823112854881], [-0.7670541948555989, 0.6415823112854881]], [[-0.7589232914680891, 0.6511800347578556], [-0.7589232914680891, 0.6511800347578556]], [[-0.7506723052527245, 0.6606747233900812], [-0.7506723052527245, 0.6606747233900812]], [[-0.7423025417456096, 0.670064874857657], [-0.7423025417456096, 0.670064874857657]], [[-0.7338153252767281, 0.6793490033767694], [-0.7338153252767281, 0.6793490033767694]], [[-0.7252119987603977, 0.6885256399393918], [-0.7252119987603977, 0.6885256399393918]], [[-0.7164939234827836, 0.6975933325457224], [-0.7164939234827836, 0.6975933325457224]], [[-0.7076624788865049, 0.706550646433932], [-0.7076624788865049, 0.706550646433932]], [[-0.698719062352368, 0.7153961643071813], [-0.698719062352368, 0.7153961643071813]], [[-0.6896650889782625, 0.7241284865578796], [-0.6896650889782625, 0.7241284865578796]], [[-0.6805019913552531, 0.7327462314891391], [-0.6805019913552531, 0.7327462314891391]], [[-0.6712312193409035, 0.7412480355333995], [-0.6712312193409035, 0.7412480355333995]], [[-0.6618542398298681, 
0.7496325534681825], [-0.6618542398298681, 0.7496325534681825]], [[-0.6523725365217912, 0.7578984586289408], [-0.6523725365217912, 0.7578984586289408]], [[-0.6427876096865396, 0.7660444431189778], [-0.6427876096865396, 0.7660444431189778]], [[-0.6331009759268216, 0.7740692180163904], [-0.6331009759268216, 0.7740692180163904]], [[-0.623314167938217, 0.7819715135780128], [-0.623314167938217, 0.7819715135780128]], [[-0.6134287342666622, 0.7897500794403256], [-0.6134287342666622, 0.7897500794403256]], [[-0.6034462390634266, 0.7974036848172986], [-0.6034462390634266, 0.7974036848172986]], [[-0.5933682618376209, 0.8049311186951345], [-0.5933682618376209, 0.8049311186951345]], [[-0.5831963972062739, 0.8123311900238854], [-0.5831963972062739, 0.8123311900238854]], [[-0.5729322546420206, 0.819602727905911], [-0.5729322546420206, 0.819602727905911]], [[-0.5625774582184379, 0.826744581781146], [-0.5625774582184379, 0.826744581781146]], [[-0.552133646353071, 0.8337556216091511], [-0.552133646353071, 0.8337556216091511]], [[-0.541602471548191, 0.8406347380479176], [-0.541602471548191, 0.8406347380479176]], [[-0.5309856001293205, 0.8473808426293961], [-0.5309856001293205, 0.8473808426293961]], [[-0.5202847119815792, 0.8539928679317206], [-0.5202847119815792, 0.8539928679317206]], [[-0.5095015002838734, 0.8604697677481075], [-0.5095015002838734, 0.8604697677481075]], [[-0.4986376712409919, 0.8668105172523927], [-0.4986376712409919, 0.8668105172523927]], [[-0.487694943813635, 0.8730141131611879], [-0.487694943813635, 0.8730141131611879]], [[-0.47667504944642797, 0.8790795738926286], [-0.47667504944642797, 0.8790795738926286]], [[-0.4655797317939577, 0.8850059397216871], [-0.4655797317939577, 0.8850059397216871]], [[-0.45441074644487806, 0.890792272932028], [-0.45441074644487806, 0.890792272932028]], [[-0.4431698606441268, 0.8964376579643814], [-0.4431698606441268, 0.8964376579643814]], [[-0.4318588530132981, 0.9019412015614092], [-0.4318588530132981, 0.9019412015614092]], [[-0.4204795132692152, 0.907302032909044], [-0.4204795132692152, 0.907302032909044]], [[-0.4090336419407468, 0.9125193037742757], [-0.4090336419407468, 0.9125193037742757]], [[-0.3975230500839139, 0.9175921886393661], [-0.3975230500839139, 0.9175921886393661]], [[-0.38594955899532896, 0.9225198848324686], [-0.38594955899532896, 0.9225198848324686]], [[-0.3743149999240192, 0.9273016126546322], [-0.3743149999240192, 0.9273016126546322]], [[-0.3626212137816673, 0.9319366155031737], [-0.3626212137816673, 0.9319366155031737]], [[-0.35087005085133094, 0.9364241599913922], [-0.35087005085133094, 0.9364241599913922]], [[-0.3390633704946757, 0.9407635360646108], [-0.3390633704946757, 0.9407635360646108]], [[-0.3272030408577722, 0.9449540571125281], [-0.3272030408577722, 0.9449540571125281]], [[-0.3152909385755031, 0.9489950600778585], [-0.3152909385755031, 0.9489950600778585]], [[-0.3033289484746273, 0.9528859055612465], [-0.3033289484746273, 0.9528859055612465]], [[-0.29131896327554796, 0.9566259779224375], [-0.29131896327554796, 0.9566259779224375]], [[-0.2792628832928309, 0.9602146853776892], [-0.2792628832928309, 0.9602146853776892]], [[-0.26716261613452225, 0.9636514600934084], [-0.26716261613452225, 0.9636514600934084]], [[-0.25502007640031144, 0.9669357582759981], [-0.25502007640031144, 0.9669357582759981]], [[-0.24283718537858734, 0.9700670602579007], [-0.24283718537858734, 0.9700670602579007]], [[-0.23061587074244044, 0.9730448705798238], [-0.23061587074244044, 0.9730448705798238]], [[-0.21835806624464577, 0.975868718069136], 
[-0.21835806624464577, 0.975868718069136]], [[-0.20606571141169297, 0.9785381559144195], [-0.20606571141169297, 0.9785381559144195]], [[-0.19374075123689813, 0.981052761736168], [-0.19374075123689813, 0.981052761736168]], [[-0.18138513587265162, 0.9834121376536186], [-0.18138513587265162, 0.9834121376536186]], [[-0.16900082032184968, 0.9856159103477083], [-0.16900082032184968, 0.9856159103477083]], [[-0.15658976412855838, 0.9876637311201432], [-0.15658976412855838, 0.9876637311201432]], [[-0.14415393106795907, 0.9895552759485718], [-0.14415393106795907, 0.9895552759485718]], [[-0.13169528883562445, 0.9912902455378553], [-0.13169528883562445, 0.9912902455378553]], [[-0.11921580873617425, 0.9928683653674237], [-0.11921580873617425, 0.9928683653674237]], [[-0.10671746537135988, 0.9942893857347128], [-0.10671746537135988, 0.9942893857347128]], [[-0.0942022363276273, 0.9955530817946745], [-0.0942022363276273, 0.9955530817946745]], [[-0.08167210186320688, 0.9966592535953529], [-0.08167210186320688, 0.9966592535953529]], [[-0.06912904459478485, 0.9976077261095226], [-0.06912904459478485, 0.9976077261095226]], [[-0.056575049183792726, 0.998398349262383], [-0.056575049183792726, 0.998398349262383]], [[-0.04401210202238211, 0.9990309979553044], [-0.04401210202238211, 0.9990309979553044]], [[-0.031442190919121114, 0.9995055720856215], [-0.031442190919121114, 0.9995055720856215]], [[-0.018867304784467676, 0.9998219965624732], [-0.018867304784467676, 0.9998219965624732]], [[-0.006289433316068405, 0.9999802213186832], [-0.006289433316068405, 0.9999802213186832]], [[0.006289433316067026, 0.9999802213186832], [0.006289433316067026, 0.9999802213186832]], [[0.0188673047844663, 0.9998219965624732], [0.0188673047844663, 0.9998219965624732]], [[0.03144219091911974, 0.9995055720856215], [0.03144219091911974, 0.9995055720856215]], [[0.04401210202238073, 0.9990309979553045], [0.04401210202238073, 0.9990309979553045]], [[0.056575049183791346, 0.9983983492623831], [0.056575049183791346, 0.9983983492623831]], [[0.06912904459478347, 0.9976077261095226], [0.06912904459478347, 0.9976077261095226]], [[0.08167210186320639, 0.9966592535953529], [0.08167210186320639, 0.9966592535953529]], [[0.09420223632762592, 0.9955530817946746], [0.09420223632762592, 0.9955530817946746]], [[0.10671746537135851, 0.994289385734713], [0.10671746537135851, 0.994289385734713]], [[0.11921580873617288, 0.9928683653674238], [0.11921580873617288, 0.9928683653674238]], [[0.13169528883562306, 0.9912902455378555], [0.13169528883562306, 0.9912902455378555]], [[0.14415393106795768, 0.9895552759485721], [0.14415393106795768, 0.9895552759485721]], [[0.15658976412855702, 0.9876637311201434], [0.15658976412855702, 0.9876637311201434]], [[0.16900082032184832, 0.9856159103477086], [0.16900082032184832, 0.9856159103477086]], [[0.18138513587265026, 0.9834121376536189], [0.18138513587265026, 0.9834121376536189]], [[0.19374075123689677, 0.9810527617361683], [0.19374075123689677, 0.9810527617361683]], [[0.2060657114116916, 0.9785381559144198], [0.2060657114116916, 0.9785381559144198]], [[0.21835806624464443, 0.9758687180691363], [0.21835806624464443, 0.9758687180691363]], [[0.2306158707424391, 0.9730448705798241], [0.2306158707424391, 0.9730448705798241]], [[0.24283718537858687, 0.9700670602579009], [0.24283718537858687, 0.9700670602579009]], [[0.2550200764003101, 0.9669357582759984], [0.2550200764003101, 0.9669357582759984]], [[0.2671626161345209, 0.9636514600934087], [0.2671626161345209, 0.9636514600934087]], [[0.2792628832928296, 0.9602146853776896], 
[0.2792628832928296, 0.9602146853776896]], [[0.2913189632755466, 0.956625977922438], [0.2913189632755466, 0.956625977922438]], [[0.30332894847462605, 0.952885905561247], [0.30332894847462605, 0.952885905561247]], [[0.3152909385755018, 0.9489950600778589], [0.3152909385755018, 0.9489950600778589]], [[0.3272030408577709, 0.9449540571125286], [0.3272030408577709, 0.9449540571125286]], [[0.33906337049467444, 0.9407635360646113], [0.33906337049467444, 0.9407635360646113]], [[0.3508700508513296, 0.9364241599913926], [0.3508700508513296, 0.9364241599913926]], [[0.36262121378166595, 0.9319366155031743], [0.36262121378166595, 0.9319366155031743]], [[0.3743149999240179, 0.9273016126546327], [0.3743149999240179, 0.9273016126546327]], [[0.3859495589953277, 0.9225198848324692], [0.3859495589953277, 0.9225198848324692]], [[0.39752305008391264, 0.9175921886393666], [0.39752305008391264, 0.9175921886393666]], [[0.40903364194074554, 0.9125193037742763], [0.40903364194074554, 0.9125193037742763]], [[0.4204795132692139, 0.9073020329090445], [0.4204795132692139, 0.9073020329090445]], [[0.4318588530132969, 0.9019412015614098], [0.4318588530132969, 0.9019412015614098]], [[0.44316986064412556, 0.896437657964382], [0.44316986064412556, 0.896437657964382]], [[0.45441074644487683, 0.8907922729320287], [0.45441074644487683, 0.8907922729320287]], [[0.46557973179395645, 0.8850059397216877], [0.46557973179395645, 0.8850059397216877]], [[0.47667504944642675, 0.8790795738926293], [0.47667504944642675, 0.8790795738926293]], [[0.48769494381363376, 0.8730141131611886], [0.48769494381363376, 0.8730141131611886]], [[0.4986376712409907, 0.8668105172523933], [0.4986376712409907, 0.8668105172523933]], [[0.5095015002838723, 0.8604697677481082], [0.5095015002838723, 0.8604697677481082]], [[0.520284711981578, 0.8539928679317214], [0.520284711981578, 0.8539928679317214]], [[0.5309856001293194, 0.8473808426293968], [0.5309856001293194, 0.8473808426293968]], [[0.5416024715481897, 0.8406347380479183], [0.5416024715481897, 0.8406347380479183]], [[0.5521336463530699, 0.8337556216091518], [0.5521336463530699, 0.8337556216091518]], [[0.5625774582184366, 0.8267445817811466], [0.5625774582184366, 0.8267445817811466]], [[0.5729322546420195, 0.8196027279059118], [0.5729322546420195, 0.8196027279059118]], [[0.5831963972062728, 0.8123311900238863], [0.5831963972062728, 0.8123311900238863]], [[0.5933682618376198, 0.8049311186951352], [0.5933682618376198, 0.8049311186951352]], [[0.6034462390634255, 0.7974036848172994], [0.6034462390634255, 0.7974036848172994]], [[0.6134287342666611, 0.7897500794403265], [0.6134287342666611, 0.7897500794403265]], [[0.6233141679382159, 0.7819715135780135], [0.6233141679382159, 0.7819715135780135]], [[0.6331009759268206, 0.7740692180163913], [0.6331009759268206, 0.7740692180163913]], [[0.6427876096865385, 0.7660444431189787], [0.6427876096865385, 0.7660444431189787]], [[0.6523725365217901, 0.7578984586289417], [0.6523725365217901, 0.7578984586289417]], [[0.6618542398298678, 0.7496325534681827], [0.6618542398298678, 0.7496325534681827]], [[0.6712312193409025, 0.7412480355334005], [0.6712312193409025, 0.7412480355334005]], [[0.6805019913552521, 0.7327462314891401], [0.6805019913552521, 0.7327462314891401]], [[0.6896650889782615, 0.7241284865578805], [0.6896650889782615, 0.7241284865578805]], [[0.698719062352367, 0.7153961643071823], [0.698719062352367, 0.7153961643071823]], [[0.7076624788865039, 0.7065506464339328], [0.7076624788865039, 0.7065506464339328]], [[0.7164939234827827, 0.6975933325457234], 
[0.7164939234827827, 0.6975933325457234]], [[0.7252119987603968, 0.6885256399393928], [0.7252119987603968, 0.6885256399393928]], [[0.7338153252767271, 0.6793490033767704], [0.7338153252767271, 0.6793490033767704]], [[0.7423025417456087, 0.670064874857658], [0.7423025417456087, 0.670064874857658]], [[0.7506723052527237, 0.6606747233900823], [0.7506723052527237, 0.6606747233900823]], [[0.7589232914680881, 0.6511800347578566], [0.7589232914680881, 0.6511800347578566]], [[0.767054194855598, 0.6415823112854891], [0.767054194855598, 0.6415823112854891]], [[0.7750637288796014, 0.6318830716004724], [0.7750637288796014, 0.6318830716004724]], [[0.7829506262084629, 0.6220838503929964], [0.7829506262084629, 0.6220838503929964]], [[0.7907136389150935, 0.612186198173114], [0.7907136389150935, 0.612186198173114]], [[0.7983515386744056, 0.60219168102541], [0.7983515386744056, 0.60219168102541]], [[0.8058631169576688, 0.5921018803612016], [0.8058631169576688, 0.5921018803612016]], [[0.8132471852237325, 0.5819183926683157], [0.8132471852237325, 0.5819183926683157]], [[0.820502575107087, 0.5716428292584793], [0.820502575107087, 0.5716428292584793]], [[0.8276281386027308, 0.5612768160123658], [0.8276281386027308, 0.5612768160123658]], [[0.8346227482478168, 0.5508219931223347], [0.8346227482478168, 0.5508219931223347]], [[0.8414852973000496, 0.5402800148329078], [0.8414852973000496, 0.5402800148329078]], [[0.8482146999128017, 0.5296525491790214], [0.8482146999128017, 0.5296525491790214]], [[0.8548098913069254, 0.5189412777220967], [0.8548098913069254, 0.5189412777220967]], [[0.8612698279392301, 0.5081478952839703], [0.8612698279392301, 0.5081478952839703]], [[0.8675934876676011, 0.49727410967872443], [0.8675934876676011, 0.49727410967872443]], [[0.8737798699127283, 0.48632164144246715], [0.8737798699127283, 0.48632164144246715]], [[0.8798279958164291, 0.4752922235610904], [0.8798279958164291, 0.4752922235610904]], [[0.8857369083965291, 0.4641876011960666], [0.8857369083965291, 0.4641876011960666]], [[0.8915056726982836, 0.4530095314083147], [0.8915056726982836, 0.4530095314083147]], [[0.8971333759423138, 0.4417597828801838], [0.8971333759423138, 0.4417597828801838]], [[0.9026191276690336, 0.43044013563559885], [0.9026191276690336, 0.43044013563559885]], [[0.9079620598795458, 0.4190523807584107], [0.9079620598795458, 0.4190523807584107]], [[0.9131613271729829, 0.4075983201089971], [0.9131613271729829, 0.4075983201089971]], [[0.9182161068802737, 0.39607976603915773], [0.9182161068802737, 0.39607976603915773]], [[0.9231255991943119, 0.3844985411053501], [0.9231255991943119, 0.3844985411053501]], [[0.9278890272965089, 0.37285647778030967], [0.9278890272965089, 0.37285647778030967]], [[0.932505637479707, 0.36115541816310226], [0.932505637479707, 0.36115541816310226]], [[0.9369746992674379, 0.3493972136876513], [0.9369746992674379, 0.3493972136876513]], [[0.9412955055295031, 0.3375837248297927], [0.9412955055295031, 0.3375837248297927]], [[0.9454673725938633, 0.32571682081289105], [0.9454673725938633, 0.32571682081289105]], [[0.9494896403548132, 0.3137983793120792], [0.9494896403548132, 0.3137983793120792]], [[0.9533616723774291, 0.3018302861571574], [0.9533616723774291, 0.3018302861571574]], [[0.9570828559982706, 0.2898144350342019], [0.9570828559982706, 0.2898144350342019]], [[0.9606526024223209, 0.27775272718593136], [0.9606526024223209, 0.27775272718593136]], [[0.9640703468161504, 0.26564707111087715], [0.9640703468161504, 0.26564707111087715]], [[0.96733554839729, 0.25349938226140567], [0.96733554839729, 
0.25349938226140567]], [[0.9704476905197967, 0.24131158274064027], [0.9704476905197967, 0.24131158274064027]], [[0.9734062807560024, 0.22908560099833106], [0.9734062807560024, 0.22908560099833106]], [[0.9762108509744293, 0.21682337152572034], [0.9762108509744293, 0.21682337152572034]], [[0.9788609574138614, 0.20452683454945125], [0.9788609574138614, 0.20452683454945125]], [[0.9813561807535595, 0.1921979357245733], [0.9813561807535595, 0.1921979357245733]], [[0.98369612617961, 0.17983862582668034], [0.98369612617961, 0.17983862582668034]], [[0.9858804234473957, 0.1674508604432468], [0.9858804234473957, 0.1674508604432468]], [[0.987908726940178, 0.15503659966419847], [0.987908726940178, 0.15503659966419847]], [[0.9897807157237833, 0.14259780777177156], [0.9897807157237833, 0.14259780777177156]], [[0.9914960935973847, 0.13013645292970846], [0.9914960935973847, 0.13013645292970846]], [[0.9930545891403676, 0.11765450687183943], [0.9930545891403676, 0.11765450687183943]], [[0.9944559557552775, 0.1051539445900992], [0.9944559557552775, 0.1051539445900992]], [[0.9956999717068375, 0.09263674402202833], [0.9956999717068375, 0.09263674402202833]], [[0.9967864401570342, 0.08010488573780816], [0.9967864401570342, 0.08010488573780816]], [[0.9977151891962615, 0.06756035262687954], [0.9977151891962615, 0.06756035262687954]], [[0.9984860718705224, 0.05500512958419429], [0.9984860718705224, 0.05500512958419429]], [[0.9990989662046814, 0.042441203196148705], [0.9990989662046814, 0.042441203196148705]], [[0.9995537752217638, 0.029870561426253633], [0.9995537752217638, 0.029870561426253633]], [[0.9998504269583004, 0.01729519330057795], [0.9998504269583004, 0.01729519330057795]], [[0.999988874475714, 0.004717088593032691], [0.999988874475714, 0.004717088593032691]], [[0.999969095867747, -0.007861762489467534], [0.999969095867747, -0.007861762489467534]], [[0.9997910942639262, -0.020439369621910786], [0.9997910942639262, -0.020439369621910786]], [[0.9994548978290694, -0.03301374267611272], [0.9994548978290694, -0.03301374267611272]], [[0.9989605597588275, -0.045582892035610355], [0.9989605597588275, -0.045582892035610355]], [[0.9983081582712683, -0.058144828910474865], [0.9983081582712683, -0.058144828910474865]], [[0.9974977965944998, -0.07069756565199363], [0.9974977965944998, -0.07069756565199363]], [[0.9965296029503368, -0.08323911606717167], [0.9965296029503368, -0.08323911606717167]], [[0.9954037305340127, -0.09576749573300279], [0.9954037305340127, -0.09576749573300279]], [[0.9941203574899394, -0.1082807223104606], [0.9941203574899394, -0.1082807223104606]], [[0.9926796868835203, -0.12077681585816072], [0.9926796868835203, -0.12077681585816072]], [[0.9910819466690197, -0.1332537991456392], [0.9910819466690197, -0.1332537991456392]], [[0.9893273896534936, -0.14570969796621086], [0.9893273896534936, -0.14570969796621086]], [[0.9874162934567892, -0.1581425414493393], [0.9874162934567892, -0.1581425414493393]], [[0.9853489604676167, -0.17055036237248902], [0.9853489604676167, -0.17055036237248902]], [[0.9831257177957046, -0.18293119747238504], [0.9831257177957046, -0.18293119747238504]], [[0.9807469172200398, -0.1952830877556692], [0.9807469172200398, -0.1952830877556692]], [[0.9782129351332084, -0.2076040788088552], [0.9782129351332084, -0.2076040788088552]], [[0.9755241724818389, -0.2198922211075767], [0.9755241724818389, -0.2198922211075767]], [[0.9726810547031601, -0.23214557032506142], [0.9726810547031601, -0.23214557032506142]], [[0.9696840316576879, -0.24436218763976586], [0.9696840316576879, 
-0.24436218763976586]], [[0.9665335775580415, -0.25654014004216474], [0.9665335775580415, -0.25654014004216474]], [[0.9632301908939129, -0.2686775006405933], [0.9632301908939129, -0.2686775006405933]], [[0.9597743943531892, -0.2807723489661489], [0.9597743943531892, -0.2807723489661489]], [[0.9561667347392514, -0.29282277127654904], [0.9561667347392514, -0.29282277127654904]], [[0.9524077828844516, -0.3048268608589526], [0.9524077828844516, -0.3048268608589526]], [[0.9484981335597957, -0.3167827183316413], [0.9484981335597957, -0.3167827183316413]], [[0.9444384053808291, -0.32868845194456814], [0.9444384053808291, -0.32868845194456814]], [[0.9402292407097596, -0.340542177878672], [0.9402292407097596, -0.340542177878672]], [[0.9358713055538124, -0.3523420205439635], [0.9358713055538124, -0.3523420205439635]], [[0.9313652894598542, -0.36408611287628373], [0.9313652894598542, -0.36408611287628373]], [[0.9267119054052854, -0.37577259663273127], [0.9267119054052854, -0.37577259663273127]], [[0.9219118896852252, -0.38739962268569283], [0.9219118896852252, -0.38739962268569283]], [[0.9169660017960138, -0.3989653513154153], [0.9169660017960138, -0.3989653513154153]], [[0.9118750243150339, -0.4104679525011135], [0.9118750243150339, -0.4104679525011135]], [[0.9066397627768898, -0.4219056062105182], [0.9066397627768898, -0.4219056062105182]], [[0.901261045545945, -0.4332765026878681], [0.901261045545945, -0.4332765026878681]], [[0.895739723685256, -0.44457884274025133], [0.895739723685256, -0.44457884274025133]], [[0.8900766708219062, -0.45581083802230066], [0.8900766708219062, -0.45581083802230066]], [[0.8842727830087785, -0.46697071131914664], [0.8842727830087785, -0.46697071131914664]], [[0.878328978582769, -0.47805669682763535], [0.878328978582769, -0.47805669682763535]], [[0.8722461980194871, -0.48906704043571536], [0.8722461980194871, -0.48906704043571536]], [[0.8660254037844392, -0.4999999999999992], [0.8660254037844392, -0.4999999999999992]], [[0.8596675801807453, -0.5108538456214086], [0.8596675801807453, -0.5108538456214086]], [[0.8531737331933934, -0.5216268599188969], [0.8531737331933934, -0.5216268599188969]], [[0.8465448903300608, -0.5323173383011919], [0.8465448903300608, -0.5323173383011919]], [[0.8397821004585404, -0.5429235892364983], [0.8397821004585404, -0.5429235892364983]], [[0.8328864336407736, -0.5534439345201582], [0.8328864336407736, -0.5534439345201582]], [[0.8258589809635439, -0.5638767095401768], [0.8258589809635439, -0.5638767095401768]], [[0.8187008543658284, -0.5742202635406232], [0.8187008543658284, -0.5742202635406232]], [[0.8114131864628666, -0.5844729598828138], [0.8114131864628666, -0.5844729598828138]], [[0.803997130366941, -0.5946331763042861], [0.803997130366941, -0.5946331763042861]], [[0.7964538595049301, -0.6046993051754741], [0.7964538595049301, -0.6046993051754741]], [[0.7887845674326319, -0.6146697537540917], [0.7887845674326319, -0.6146697537540917]], [[0.7809904676459185, -0.6245429444371375], [0.7809904676459185, -0.6245429444371375]], [[0.7730727933887184, -0.6343173150105269], [0.7730727933887184, -0.6343173150105269]], [[0.76503279745789, -0.6439913188962683], [0.76503279745789, -0.6439913188962683]], [[0.7568717520049925, -0.6535634253971785], [0.7568717520049925, -0.6535634253971785]], [[0.7485909483349908, -0.6630321199390865], [0.7485909483349908, -0.6630321199390865]], [[0.7401916967019444, -0.6723959043104716], [0.7401916967019444, -0.6723959043104716]], [[0.7316753261016786, -0.6816532968995326], [0.7316753261016786, -0.6816532968995326]], 
[[0.7230431840615102, -0.69080283292861], [0.7230431840615102, -0.69080283292861]], [[0.7142966364270213, -0.6998430646859649], [0.7142966364270213, -0.6998430646859649]], [[0.7054370671459542, -0.7087725617548373], [0.7054370671459542, -0.7087725617548373]], [[0.6964658780492222, -0.7175899112397874], [0.6964658780492222, -0.7175899112397874]], [[0.6873844886291115, -0.7262937179902459], [0.6873844886291115, -0.7262937179902459]], [[0.678194335814667, -0.7348826048212753], [0.678194335814667, -0.7348826048212753]], [[0.6688968737443408, -0.7433552127314689], [0.6688968737443408, -0.7433552127314689]], [[0.6594935735358967, -0.7517102011179926], [0.6594935735358967, -0.7517102011179926]], [[0.6499859230536468, -0.7599462479886974], [0.6499859230536468, -0.7599462479886974]], [[0.6403754266730268, -0.7680620501712988], [0.6403754266730268, -0.7680620501712988]], [[0.6306636050425575, -0.7760563235195788], [0.6306636050425575, -0.7760563235195788]], [[0.6208519948432446, -0.7839278031165648], [0.6208519948432446, -0.7839278031165648]], [[0.6109421485454233, -0.7916752434746851], [0.6109421485454233, -0.7916752434746851]], [[0.600935634163124, -0.7992974187328293], [0.600935634163124, -0.7992974187328293]], [[0.5908340350059585, -0.8067931228503239], [0.5908340350059585, -0.8067931228503239]], [[0.5806389494286068, -0.8141611697977519], [0.5806389494286068, -0.8141611697977519]], [[0.570351990577902, -0.8214003937446248], [0.570351990577902, -0.8214003937446248]], [[0.5599747861375968, -0.8285096492438412], [0.5599747861375968, -0.8285096492438412]], [[0.5495089780708068, -0.8354878114129359], [0.5495089780708068, -0.8354878114129359]], [[0.5389562223602165, -0.8423337761120617], [0.5389562223602165, -0.8423337761120617]], [[0.5283181887460523, -0.8490464601186973], [0.5283181887460523, -0.8490464601186973]], [[0.5175965604618786, -0.8556248012990465], [0.5175965604618786, -0.8556248012990465]], [[0.5067930339682736, -0.8620677587760909], [0.5067930339682736, -0.8620677587760909]], [[0.49590931868438975, -0.8683743130942925], [0.49590931868438975, -0.8683743130942925]], [[0.4849471367174889, -0.8745434663808935], [0.4849471367174889, -0.8745434663808935]], [[0.4739082225904436, -0.8805742425038144], [0.4739082225904436, -0.8805742425038144]], [[0.4627943229673003, -0.886465687226098], [0.4627943229673003, -0.886465687226098]], [[0.4516071963768956, -0.8922168683569035], [0.4516071963768956, -0.8922168683569035]], [[0.44034861293462074, -0.8978268758989985], [0.44034861293462074, -0.8978268758989985]], [[0.42902035406232714, -0.903294822192752], [0.42902035406232714, -0.903294822192752]], [[0.4176242122064685, -0.9086198420565812], [0.4176242122064685, -0.9086198420565812]], [[0.4061619905544733, -0.9138010929238529], [0.4061619905544733, -0.9138010929238529]], [[0.3946355027494409, -0.918837754976196], [0.3946355027494409, -0.918837754976196]], [[0.38304657260316866, -0.9237290312732221], [0.38304657260316866, -0.9237290312732221]], [[0.37139703380756833, -0.9284741478786256], [0.37139703380756833, -0.9284741478786256]], [[0.3596887296445368, -0.9330723539826369], [0.3596887296445368, -0.9330723539826369]], [[0.34792351269428423, -0.9375229220208273], [0.34792351269428423, -0.9375229220208273]], [[0.3361032445422173, -0.9418251477892244], [0.3361032445422173, -0.9418251477892244]], [[0.3242297954843714, -0.9459783505557422], [0.3242297954843714, -0.9459783505557422]], [[0.31230504423149086, -0.9499818731678866], [0.31230504423149086, -0.9499818731678866]], [[0.3003308776117511, 
-0.9538350821567402], [0.3003308776117511, -0.9538350821567402]], [[0.28830919027222335, -0.9575373678371905], [0.28830919027222335, -0.9575373678371905]], [[0.27624188437907515, -0.9610881444044025], [0.27624188437907515, -0.9610881444044025]], [[0.264130869316608, -0.9644868500265066], [0.264130869316608, -0.9644868500265066]], [[0.2519780613851261, -0.9677329469334987], [0.2519780613851261, -0.9677329469334987]], [[0.2397853834977361, -0.9708259215023276], [0.2397853834977361, -0.9708259215023276]], [[0.22755476487608342, -0.9737652843381666], [0.22755476487608342, -0.9737652843381666]], [[0.2152881407450906, -0.9765505703518492], [0.2152881407450906, -0.9765505703518492]], [[0.20298745202676252, -0.9791813388334577], [0.20298745202676252, -0.9791813388334577]], [[0.19065464503306495, -0.9816571735220581], [0.19065464503306495, -0.9816571735220581]], [[0.17829167115797728, -0.9839776826715613], [0.17829167115797728, -0.9839776826715613]], [[0.1659004865687139, -0.9861424991127113], [0.1659004865687139, -0.9861424991127113]], [[0.15348305189621775, -0.9881512803111794], [0.15348305189621775, -0.9881512803111794]], [[0.14104133192492, -0.9900037084217637], [0.14104133192492, -0.9900037084217637]], [[0.12857729528187029, -0.9916994903386805], [0.12857729528187029, -0.9916994903386805]], [[0.11609291412523105, -0.9932383577419429], [0.11609291412523105, -0.9932383577419429]], [[0.10359016383224108, -0.9946200671398147], [0.10359016383224108, -0.9946200671398147]], [[0.09107102268664179, -0.9958443999073395], [0.09107102268664179, -0.9958443999073395]], [[0.07853747156566976, -0.996911162320932], [0.07853747156566976, -0.996911162320932]], [[0.0659914936266216, -0.9978201855890306], [0.0659914936266216, -0.9978201855890306]], [[0.05343507399305771, -0.9985713258788059], [0.05343507399305771, -0.9985713258788059]], [[0.04087019944071283, -0.9991644643389177], [0.04087019944071283, -0.9991644643389177]], [[0.028298858083118522, -0.9995995071183216], [0.028298858083118522, -0.9995995071183216]], [[0.01572303905704239, -0.9998763853811183], [0.01572303905704239, -0.9998763853811183]], [[0.003144732207736932, -0.9999950553174458], [0.003144732207736932, -0.9999950553174458]], [[-0.009434072225895224, -0.999955498150411], [-0.009434072225895224, -0.999955498150411]], [[-0.02201138392622685, -0.9997577201390606], [-0.02201138392622685, -0.9997577201390606]], [[-0.03458521281181564, -0.9994017525773914], [-0.03458521281181564, -0.9994017525773914]], [[-0.04715356935230482, -0.9988876517893979], [-0.04715356935230482, -0.9988876517893979]], [[-0.05971446488320808, -0.9982154991201609], [-0.05971446488320808, -0.9982154991201609]], [[-0.07226591192058601, -0.9973854009229762], [-0.07226591192058601, -0.9973854009229762]], [[-0.08480592447550901, -0.9963974885425265], [-0.08480592447550901, -0.9963974885425265]], [[-0.0973325183683015, -0.9952519182940992], [-0.0973325183683015, -0.9952519182940992]], [[-0.1098437115424997, -0.9939488714388522], [-0.1098437115424997, -0.9939488714388522]], [[-0.12233752437845594, -0.9924885541551351], [-0.12233752437845594, -0.9924885541551351]], [[-0.13481198000658376, -0.9908711975058637], [-0.13481198000658376, -0.9908711975058637]], [[-0.14726510462013975, -0.9890970574019616], [-0.14726510462013975, -0.9890970574019616]], [[-0.15969492778754882, -0.9871664145618658], [-0.15969492778754882, -0.9871664145618658]], [[-0.17209948276416748, -0.9850795744671118], [-0.17209948276416748, -0.9850795744671118]], [[-0.18447680680349163, -0.9828368673139949], 
[-0.18447680680349163, -0.9828368673139949]], [[-0.19682494146770374, -0.9804386479613271], [-0.19682494146770374, -0.9804386479613271]], [[-0.2091419329375665, -0.9778852958742853], [-0.2091419329375665, -0.9778852958742853]], [[-0.22142583232155733, -0.9751772150643726], [-0.22142583232155733, -0.9751772150643726]], [[-0.23367469596425144, -0.9723148340254892], [-0.23367469596425144, -0.9723148340254892]], [[-0.24588658575385006, -0.9692986056661356], [-0.24588658575385006, -0.9692986056661356]], [[-0.2580595694288491, -0.9661290072377483], [-0.2580595694288491, -0.9661290072377483]], [[-0.2701917208837818, -0.9628065402591844], [-0.2701917208837818, -0.9628065402591844]], [[-0.2822811204739704, -0.9593317304373705], [-0.2822811204739704, -0.9593317304373705]], [[-0.29432585531928135, -0.9557051275841171], [-0.29432585531928135, -0.9557051275841171]], [[-0.30632401960678207, -0.951927305529127], [-0.30632401960678207, -0.951927305529127]], [[-0.31827371489230794, -0.9479988620291956], [-0.31827371489230794, -0.9479988620291956]], [[-0.3301730504008353, -0.9439204186736335], [-0.3301730504008353, -0.9439204186736335]], [[-0.342020143325668, -0.9396926207859086], [-0.342020143325668, -0.9396926207859086]], [[-0.35381311912633706, -0.9353161373215435], [-0.35381311912633706, -0.9353161373215435]], [[-0.3655501118252182, -0.9307916607622624], [-0.3655501118252182, -0.9307916607622624]], [[-0.37722926430276815, -0.9261199070064267], [-0.37722926430276815, -0.9261199070064267]], [[-0.3888487285913865, -0.9213016152557545], [-0.3888487285913865, -0.9213016152557545]], [[-0.4004066661678036, -0.9163375478983632], [-0.4004066661678036, -0.9163375478983632]], [[-0.4119012482439916, -0.9112284903881362], [-0.4119012482439916, -0.9112284903881362]], [[-0.4233306560565341, -0.9059752511204401], [-0.4233306560565341, -0.9059752511204401]], [[-0.4346930811543944, -0.9005786613042189], [-0.4346930811543944, -0.9005786613042189]], [[-0.4459867256850755, -0.8950395748304681], [-0.4459867256850755, -0.8950395748304681]], [[-0.4572098026790778, -0.8893588681371309], [-0.4572098026790778, -0.8893588681371309]], [[-0.46836053633265995, -0.8835374400704156], [-0.46836053633265995, -0.8835374400704156]], [[-0.47943716228880834, -0.8775762117425784], [-0.47943716228880834, -0.8775762117425784]], [[-0.4904379279164198, -0.8714761263861728], [-0.4904379279164198, -0.8714761263861728]], [[-0.5013610925876044, -0.8652381492048091], [-0.5013610925876044, -0.8652381492048091]], [[-0.5122049279531135, -0.8588632672204265], [-0.5122049279531135, -0.8588632672204265]], [[-0.5229677182158008, -0.852352489117125], [-0.5229677182158008, -0.852352489117125]], [[-0.5336477604021214, -0.8457068450815567], [-0.5336477604021214, -0.8457068450815567]], [[-0.5442433646315787, -0.8389273866399275], [-0.5442433646315787, -0.8389273866399275]], [[-0.5547528543841161, -0.8320151864916143], [-0.5547528543841161, -0.8320151864916143]], [[-0.5651745667653925, -0.8249713383394304], [-0.5651745667653925, -0.8249713383394304]], [[-0.5755068527698889, -0.8177969567165786], [-0.5755068527698889, -0.8177969567165786]], [[-0.5857480775418389, -0.8104931768102923], [-0.5857480775418389, -0.8104931768102923]], [[-0.5958966206338965, -0.8030611542822266], [-0.5958966206338965, -0.8030611542822266]], [[-0.6059508762635476, -0.7955020650855904], [-0.6059508762635476, -0.7955020650855904]], [[-0.6159092535671783, -0.7878171052790878], [-0.6159092535671783, -0.7878171052790878]], [[-0.6257701768518052, -0.7800074908376589], [-0.6257701768518052, 
-0.7800074908376589]], [[-0.6355320858443827, -0.7720744574600873], [-0.6355320858443827, -0.7720744574600873]], [[-0.6451934359386927, -0.76401926037347], [-0.6451934359386927, -0.76401926037347]], [[-0.6547526984397336, -0.7558431741346133], [-0.6547526984397336, -0.7558431741346133]], [[-0.6642083608056132, -0.7475474924283543], [-0.6642083608056132, -0.7475474924283543]], [[-0.6735589268868657, -0.7391335278628713], [-0.6735589268868657, -0.7391335278628713]], [[-0.6828029171631881, -0.7306026117619896], [-0.6828029171631881, -0.7306026117619896]], [[-0.6919388689775459, -0.7219560939545248], [-0.6919388689775459, -0.7219560939545248]], [[-0.7009653367675964, -0.7131953425607112], [-0.7009653367675964, -0.7131953425607112]], [[-0.7098808922944282, -0.7043217437757168], [-0.7098808922944282, -0.7043217437757168]], [[-0.7186841248685372, -0.695336701650319], [-0.7186841248685372, -0.695336701650319]], [[-0.7273736415730482, -0.6862416378687342], [-0.7273736415730482, -0.6862416378687342]], [[-0.7359480674841022, -0.6770379915236775], [-0.7359480674841022, -0.6770379915236775]], [[-0.7444060458884184, -0.6677272188886492], [-0.7444060458884184, -0.6677272188886492]], [[-0.7527462384979536, -0.6583107931875202], [-0.7527462384979536, -0.6583107931875202]], [[-0.7609673256616669, -0.648790204361418], [-0.7609673256616669, -0.648790204361418]], [[-0.7690680065743155, -0.6391669588329865], [-0.7690680065743155, -0.6391669588329865]], [[-0.7770469994822877, -0.6294425792680167], [-0.7770469994822877, -0.6294425792680167]], [[-0.7849030418864043, -0.619618604334529], [-0.7849030418864043, -0.619618604334529]], [[-0.7926348907416839, -0.609696588459308], [-0.7926348907416839, -0.609696588459308]], [[-0.8002413226540318, -0.5996781015819452], [-0.8002413226540318, -0.5996781015819452]], [[-0.807721134073806, -0.5895647289064406], [-0.807721134073806, -0.5895647289064406]], [[-0.8150731414862619, -0.5793580706503675], [-0.8150731414862619, -0.5793580706503675]], [[-0.8222961815988086, -0.5690597417916851], [-0.8222961815988086, -0.5690597417916851]], [[-0.8293891115250823, -0.5586713718131927], [-0.8293891115250823, -0.5586713718131927]], [[-0.8363508089657752, -0.5481946044447112], [-0.8363508089657752, -0.5481946044447112]], [[-0.8431801723862219, -0.537631097402988], [-0.8431801723862219, -0.537631097402988]], [[-0.8498761211906855, -0.5269825221294112], [-0.8498761211906855, -0.5269825221294112]], [[-0.8564375958933453, -0.5162505635255297], [-0.8564375958933453, -0.5162505635255297]], [[-0.8628635582859301, -0.5054369196864662], [-0.8628635582859301, -0.5054369196864662]], [[-0.8691529916019983, -0.49454330163221977], [-0.8691529916019983, -0.49454330163221977]], [[-0.8753049006778127, -0.4835714330369447], [-0.8753049006778127, -0.4835714330369447]], [[-0.8813183121098064, -0.4725230499562131], [-0.8813183121098064, -0.4725230499562131]], [[-0.8871922744086038, -0.46139990055231767], [-0.8871922744086038, -0.46139990055231767]], [[-0.8929258581495678, -0.4502037448176746], [-0.8929258581495678, -0.4502037448176746]], [[-0.898518156119867, -0.43893635429633115], [-0.898518156119867, -0.43893635429633115]], [[-0.9039682834620154, -0.42759951180367056], [-0.9039682834620154, -0.42759951180367056]], [[-0.9092753778138881, -0.4161950111443084], [-0.9092753778138881, -0.4161950111443084]], [[-0.914438599445165, -0.40472465682827513], [-0.914438599445165, -0.40472465682827513]], [[-0.919457131390205, -0.39319026378547983], [-0.919457131390205, -0.39319026378547983]], [[-0.9243301795773077, 
-0.38159365707855025], [-0.9243301795773077, -0.38159365707855025]], [[-0.9290569729543624, -0.36993667161404425], [-0.9290569729543624, -0.36993667161404425]], [[-0.9336367636108461, -0.3582211518521277], [-0.9336367636108461, -0.3582211518521277]], [[-0.9380688268961654, -0.34644895151472466], [-0.9380688268961654, -0.34644895151472466]], [[-0.9423524615343185, -0.3346219332922018], [-0.9423524615343185, -0.3346219332922018]], [[-0.946486989734852, -0.32274196854865056], [-0.946486989734852, -0.32274196854865056]], [[-0.9504717573001114, -0.31081093702577167], [-0.9504717573001114, -0.31081093702577167]], [[-0.9543061337287484, -0.2988307265454612], [-0.9543061337287484, -0.2988307265454612]], [[-0.9579895123154887, -0.2868032327110909], [-0.9579895123154887, -0.2868032327110909]], [[-0.9615213102471251, -0.27473035860758444], [-0.9615213102471251, -0.27473035860758444]], [[-0.9649009686947388, -0.2626140145002827], [-0.9649009686947388, -0.2626140145002827]], [[-0.9681279529021183, -0.25045611753270025], [-0.9681279529021183, -0.25045611753270025]], [[-0.9712017522703761, -0.23825859142316594], [-0.9712017522703761, -0.23825859142316594]], [[-0.9741218804387358, -0.22602336616045093], [-0.9741218804387358, -0.22602336616045093]], [[-0.9768878753614922, -0.21375237769837674], [-0.9768878753614922, -0.21375237769837674]], [[-0.9794992993811164, -0.2014475676495055], [-0.9794992993811164, -0.2014475676495055]], [[-0.9819557392975065, -0.18911088297791753], [-0.9819557392975065, -0.18911088297791753]], [[-0.9842568064333685, -0.17674427569114207], [-0.9842568064333685, -0.17674427569114207]], [[-0.9864021366957143, -0.1643497025313075], [-0.9864021366957143, -0.1643497025313075]], [[-0.9883913906334727, -0.1519291246655162], [-0.9883913906334727, -0.1519291246655162]], [[-0.9902242534911982, -0.1394845073755471], [-0.9902242534911982, -0.1394845073755471]], [[-0.9919004352588768, -0.12701781974687945], [-0.9919004352588768, -0.12701781974687945]], [[-0.9934196707178105, -0.11453103435714257], [-0.9934196707178105, -0.11453103435714257]], [[-0.9947817194825852, -0.10202612696398496], [-0.9947817194825852, -0.10202612696398496]], [[-0.9959863660391042, -0.08950507619246842], [-0.9959863660391042, -0.08950507619246842]], [[-0.9970334197786901, -0.07696986322198038], [-0.9970334197786901, -0.07696986322198038]], [[-0.9979227150282431, -0.0644224714727701], [-0.9979227150282431, -0.0644224714727701]], [[-0.9986541110764564, -0.051864886292102175], [-0.9986541110764564, -0.051864886292102175]], [[-0.9992274921960794, -0.03929909464013164], [-0.9992274921960794, -0.03929909464013164]], [[-0.9996427676622299, -0.026727084775506123], [-0.9996427676622299, -0.026727084775506123]], [[-0.9998998717667489, -0.014150845940762564], [-0.9998998717667489, -0.014150845940762564]], [[-0.9999987638285974, -0.001572368047586014], [-0.9999987638285974, -0.001572368047586014]], [[-0.9999394282002937, 0.0110063586380641], [-0.9999394282002937, 0.0110063586380641]], [[-0.9997218742703887, 0.02358334381085534], [-0.9997218742703887, 0.02358334381085534]], [[-0.9993461364619809, 0.036156597441018276], [-0.9993461364619809, 0.036156597441018276]], [[-0.9988122742272693, 0.04872413008921046], [-0.9988122742272693, 0.04872413008921046]], [[-0.9981203720381463, 0.06128395322131545], [-0.9981203720381463, 0.06128395322131545]], [[-0.9972705393728328, 0.0738340795230701], [-0.9972705393728328, 0.0738340795230701]], [[-0.9962629106985544, 0.08637252321452737], [-0.9962629106985544, 0.08637252321452737]], 
[[-0.9950976454502662, 0.09889730036424782], [-0.9950976454502662, 0.09889730036424782]], [[-0.9937749280054243, 0.11140642920322712], [-0.9937749280054243, 0.11140642920322712]], [[-0.9922949676548137, 0.12389793043845473], [-0.9922949676548137, 0.12389793043845473]], [[-0.9906579985694319, 0.1363698275660986], [-0.9906579985694319, 0.1363698275660986]], [[-0.9888642797634358, 0.14882014718424852], [-0.9888642797634358, 0.14882014718424852]], [[-0.9869140950531602, 0.16124691930515087], [-0.9869140950531602, 0.16124691930515087]], [[-0.9848077530122081, 0.17364817766692972], [-0.9848077530122081, 0.17364817766692972]], [[-0.9825455869226281, 0.18602196004469043], [-0.9825455869226281, 0.18602196004469043]], [[-0.9801279547221767, 0.19836630856101212], [-0.9801279547221767, 0.19836630856101212]], [[-0.9775552389476866, 0.21067926999572462], [-0.9775552389476866, 0.21067926999572462]], [[-0.9748278466745344, 0.2229588960949763], [-0.9748278466745344, 0.2229588960949763]], [[-0.9719462094522341, 0.23520324387948816], [-0.9719462094522341, 0.23520324387948816]], [[-0.9689107832361499, 0.24741037595200138], [-0.9689107832361499, 0.24741037595200138]], [[-0.9657220483153551, 0.25957836080381363], [-0.9657220483153551, 0.25957836080381363]], [[-0.9623805092366339, 0.27170527312041143], [-0.9623805092366339, 0.27170527312041143]], [[-0.9588866947246498, 0.2837891940860965], [-0.9588866947246498, 0.2837891940860965]], [[-0.9552411575982872, 0.29582821168760115], [-0.9552411575982872, 0.29582821168760115]], [[-0.9514444746831768, 0.30782042101662727], [-0.9514444746831768, 0.30782042101662727]], [[-0.9474972467204302, 0.31976392457124386], [-0.9474972467204302, 0.31976392457124386]], [[-0.9434000982715814, 0.3316568325561384], [-0.9434000982715814, 0.3316568325561384]], [[-0.9391536776197683, 0.3434972631816217], [-0.9391536776197683, 0.3434972631816217]], [[-0.9347586566671513, 0.35528334296139286], [-0.9347586566671513, 0.35528334296139286]], [[-0.9302157308286049, 0.3670132070089637], [-0.9302157308286049, 0.3670132070089637]], [[-0.9255256189216783, 0.3786849993327492], [-0.9255256189216783, 0.3786849993327492]], [[-0.9206890630528639, 0.3902968731297237], [-0.9206890630528639, 0.3902968731297237]], [[-0.9157068285001696, 0.40184699107765015], [-0.9157068285001696, 0.40184699107765015]], [[-0.9105797035920364, 0.41333352562578207], [-0.9105797035920364, 0.41333352562578207]], [[-0.9053084995825972, 0.4247546592840467], [-0.9053084995825972, 0.4247546592840467]], [[-0.8998940505233184, 0.4361085849106107], [-0.8998940505233184, 0.4361085849106107]], [[-0.8943372131310279, 0.4473935059978257], [-0.8943372131310279, 0.4473935059978257]], [[-0.8886388666523561, 0.45860763695649037], [-0.8886388666523561, 0.45860763695649037]], [[-0.8827999127246203, 0.4697492033983695], [-0.8827999127246203, 0.4697492033983695]], [[-0.8768212752331539, 0.48081644241696414], [-0.8768212752331539, 0.48081644241696414]], [[-0.8707039001651283, 0.49180760286644026], [-0.8707039001651283, 0.49180760286644026]], [[-0.8644487554598653, 0.502720945638721], [-0.8644487554598653, 0.502720945638721]], [[-0.8580568308556884, 0.5135547439386501], [-0.8580568308556884, 0.5135547439386501]], [[-0.8515291377333118, 0.5243072835572309], [-0.8515291377333118, 0.5243072835572309]], [[-0.8448667089558188, 0.53497686314285], [-0.8448667089558188, 0.53497686314285]], [[-0.838070598705227, 0.5455617944704909], [-0.838070598705227, 0.5455617944704909]], [[-0.8311418823156947, 0.5560604027088458], [-0.8311418823156947, 
0.5560604027088458]], [[-0.8240816561033651, 0.5664710266853329], [-0.8240816561033651, 0.5664710266853329]], [[-0.8168910371929057, 0.5767920191489293], [-0.8168910371929057, 0.5767920191489293]], [[-0.8095711633407447, 0.5870217470308176], [-0.8095711633407447, 0.5870217470308176]], [[-0.8021231927550442, 0.5971585917027857], [-0.8021231927550442, 0.5971585917027857]], [[-0.7945483039124446, 0.6072009492333305], [-0.7945483039124446, 0.6072009492333305]], [[-0.7868476953715905, 0.6171472306414546], [-0.7868476953715905, 0.6171472306414546]], [[-0.7790225855834922, 0.6269958621480771], [-0.7790225855834922, 0.6269958621480771]], [[-0.7710742126987252, 0.6367452854250599], [-0.7710742126987252, 0.6367452854250599]], [[-0.7630038343715285, 0.6463939578417678], [-0.7630038343715285, 0.6463939578417678]], [[-0.7548127275607995, 0.6559403527091668], [-0.7548127275607995, 0.6559403527091668]], [[-0.7465021883280534, 0.6653829595213779], [-0.7465021883280534, 0.6653829595213779]], [[-0.7380735316323398, 0.6747202841946918], [-0.7380735316323398, 0.6747202841946918]], [[-0.7295280911221899, 0.6839508493039641], [-0.7295280911221899, 0.6839508493039641]], [[-0.7208672189245859, 0.6930731943163961], [-0.7208672189245859, 0.6930731943163961]], [[-0.7120922854310258, 0.7020858758226223], [-0.7120922854310258, 0.7020858758226223]], [[-0.703204679080685, 0.7109874677651012], [-0.703204679080685, 0.7109874677651012]], [[-0.694205806140723, 0.719776561663763], [-0.694205806140723, 0.719776561663763]], [[-0.685097090483782, 0.7284517668388598], [-0.685097090483782, 0.7284517668388598]], [[-0.6758799733626797, 0.7370117106310208], [-0.6758799733626797, 0.7370117106310208]], [[-0.6665559131823733, 0.745455038618435], [-0.6665559131823733, 0.745455038618435]], [[-0.6571263852691893, 0.7537804148311689], [-0.6571263852691893, 0.7537804148311689]], [[-0.6475928816373955, 0.7619865219625438], [-0.6475928816373955, 0.7619865219625438]], [[-0.6379569107531127, 0.7700720615775806], [-0.6379569107531127, 0.7700720615775806]], [[-0.6282199972956439, 0.7780357543184383], [-0.6282199972956439, 0.7780357543184383]], [[-0.6183836819162163, 0.7858763401068541], [-0.6183836819162163, 0.7858763401068541]], [[-0.6084495209942188, 0.7935925783435136], [-0.6084495209942188, 0.7935925783435136]], [[-0.5984190863909279, 0.8011832481043567], [-0.5984190863909279, 0.8011832481043567]], [[-0.5882939652008056, 0.8086471483337546], [-0.5882939652008056, 0.8086471483337546]], [[-0.5780757595003719, 0.8159830980345537], [-0.5780757595003719, 0.8159830980345537]], [[-0.5677660860947084, 0.8231899364549449], [-0.5677660860947084, 0.8231899364549449]], [[-0.5573665762616435, 0.8302665232721198], [-0.5573665762616435, 0.8302665232721198]], [[-0.546878875493628, 0.8372117387727103], [-0.546878875493628, 0.8372117387727103]], [[-0.5363046432373839, 0.8440244840299495], [-0.5363046432373839, 0.8440244840299495]], [[-0.5256455526313215, 0.850703681077561], [-0.5256455526313215, 0.850703681077561]], [[-0.5149032902408143, 0.8572482730803158], [-0.5149032902408143, 0.8572482730803158]], [[-0.5040795557913256, 0.86365722450126], [-0.5040795557913256, 0.86365722450126]], [[-0.49317606189947616, 0.8699295212655587], [-0.49317606189947616, 0.8699295212655587]], [[-0.4821945338020488, 0.8760641709209576], [-0.4821945338020488, 0.8760641709209576]], [[-0.4711367090830182, 0.8820602027948112], [-0.4711367090830182, 0.8820602027948112]], [[-0.46000433739861224, 0.8879166681476723], [-0.46000433739861224, 0.8879166681476723]], [[-0.44879918020046267, 
0.893632640323412], [-0.44879918020046267, 0.893632640323412]], [[-0.43752301045690567, 0.8992072148958361], [-0.43752301045690567, 0.8992072148958361]], [[-0.4261776123724359, 0.9046395098117977], [-0.4261776123724359, 0.9046395098117977]], [[-0.4147647811054085, 0.909928665530756], [-0.4147647811054085, 0.909928665530756]], [[-0.403286322483982, 0.9150738451607857], [-0.403286322483982, 0.9150738451607857]], [[-0.39174405272039897, 0.9200742345909907], [-0.39174405272039897, 0.9200742345909907]], [[-0.3801397981235976, 0.9249290426203247], [-0.3801397981235976, 0.9249290426203247]], [[-0.3684753948102517, 0.9296375010827764], [-0.3684753948102517, 0.9296375010827764]], [[-0.3567526884142328, 0.9341988649689195], [-0.3567526884142328, 0.9341988649689195]], [[-0.34497353379459245, 0.9386124125437886], [-0.34497353379459245, 0.9386124125437886]], [[-0.33313979474205874, 0.9428774454610838], [-0.33313979474205874, 0.9428774454610838]], [[-0.3212533436841441, 0.9469932888736632], [-0.3212533436841441, 0.9469932888736632]], [[-0.30931606138887024, 0.9509592915403249], [-0.30931606138887024, 0.9509592915403249]], [[-0.2973298366671729, 0.9547748259288534], [-0.2973298366671729, 0.9547748259288534]], [[-0.28529656607405124, 0.9584392883153082], [-0.28529656607405124, 0.9584392883153082]], [[-0.2732181536084666, 0.9619520988795546], [-0.2732181536084666, 0.9619520988795546]], [[-0.26109651041208987, 0.9653127017970029], [-0.26109651041208987, 0.9653127017970029]], [[-0.24893355446689247, 0.9685205653265596], [-0.24893355446689247, 0.9685205653265596]], [[-0.2367312102916815, 0.9715751818947599], [-0.2367312102916815, 0.9715751818947599]], [[-0.22449140863757358, 0.974476068176083], [-0.22449140863757358, 0.974476068176083]], [[-0.2122160861825098, 0.9772227651694252], [-0.2122160861825098, 0.9772227651694252]], [[-0.19990718522480572, 0.9798148382707292], [-0.19990718522480572, 0.9798148382707292]], [[-0.1875666533758392, 0.9822518773417477], [-0.1875666533758392, 0.9822518773417477]], [[-0.17519644325187023, 0.9845334967749417], [-0.17519644325187023, 0.9845334967749417]], [[-0.16279851216509478, 0.9866593355544919], [-0.16279851216509478, 0.9866593355544919]], [[-0.1503748218139381, 0.9886290573134224], [-0.1503748218139381, 0.9886290573134224]], [[-0.1379273379726542, 0.9904423503868245], [-0.1379273379726542, 0.9904423503868245]], [[-0.12545803018029758, 0.9920989278611683], [-0.12545803018029758, 0.9920989278611683]], [[-0.11296887142907358, 0.9935985276197029], [-0.11296887142907358, 0.9935985276197029]], [[-0.10046183785216964, 0.9949409123839287], [-0.10046183785216964, 0.9949409123839287]], [[-0.08793890841106214, 0.9961258697511428], [-0.08793890841106214, 0.9961258697511428]], [[-0.07540206458240344, 0.9971532122280462], [-0.07540206458240344, 0.9971532122280462]], [[-0.06285329004448297, 0.9980227772604111], [-0.06285329004448297, 0.9980227772604111]], [[-0.05029457036336817, 0.9987344272588005], [-0.05029457036336817, 0.9987344272588005]], [[-0.037727892678718344, 0.99928804962034], [-0.037727892678718344, 0.99928804962034]], [[-0.025155245389377974, 0.9996835567465338], [-0.025155245389377974, 0.9996835567465338]], [[-0.012578617838742366, 0.9999208860571255], [-0.012578617838742366, 0.9999208860571255]], [[-4.898587196589413e-16, 1.0], [-4.898587196589413e-16, 1.0]]], "init_spikes": [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 
0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], "n_neurons": 2, "intercept_": [-3.0, -3.0]}
\ No newline at end of file
diff --git a/tests/test_base_class.py b/tests/test_base_class.py
new file mode 100644
index 00000000..56d1b7dc
--- /dev/null
+++ b/tests/test_base_class.py
@@ -0,0 +1,316 @@
+from typing import Union
+
+import jax
+import jax.numpy as jnp
+import pytest
+from numpy.typing import NDArray
+
+from nemos.base_class import Base, BaseRegressor
+from nemos.utils import check_invalid_entry
+
+
+@pytest.fixture
+def mock_regressor():
+ return MockBaseRegressor()
+
+
+# Sample subclass to test instantiation and methods
+class MockBaseRegressor(BaseRegressor):
+ """
+ Mock implementation of the BaseRegressor abstract class for testing purposes.
+ Implements all required abstract methods as empty methods.
+ """
+
+ def __init__(self, std_param: int = 0):
+ """Initialize a MockBaseRegressor instance with optional standard parameters."""
+ self.std_param = std_param
+ super().__init__()
+
+ def fit(self, X: Union[NDArray, jnp.ndarray], y: Union[NDArray, jnp.ndarray]):
+ pass
+
+ def predict(self, X: Union[NDArray, jnp.ndarray]) -> jnp.ndarray:
+ pass
+
+ def score(
+ self,
+ X: Union[NDArray, jnp.ndarray],
+ y: Union[NDArray, jnp.ndarray],
+ **kwargs,
+ ) -> jnp.ndarray:
+ pass
+
+ def simulate(
+ self,
+ random_key: jax.random.PRNGKeyArray,
+ feed_forward_input: Union[NDArray, jnp.ndarray],
+ **kwargs,
+ ):
+ pass
+
+
+class MockRegressorNested(MockBaseRegressor):
+ def __init__(self, other_param: int, std_param: int = 0):
+ super().__init__(std_param=std_param)
+ self.other_param = MockBaseRegressor(std_param=other_param)
+
+
+class MockBaseRegressorInvalid(BaseRegressor):
+ """
+ Mock model that intentionally doesn't implement all the required abstract methods.
+ Used for testing the instantiation of incomplete concrete classes.
+ """
+
+ def __init__(self, std_param: int = 0):
+ self.std_param = std_param
+ super().__init__()
+
+ def predict(self, X: Union[NDArray, jnp.ndarray]) -> jnp.ndarray:
+ pass
+
+ def score(
+ self, X: Union[NDArray, jnp.ndarray], y: Union[NDArray, jnp.ndarray]
+ ) -> jnp.ndarray:
+ pass
+
+
+class BadEstimator(Base):
+ def __init__(self, param1, *args):
+ super().__init__()
+ # param1 is deliberately not stored; *args triggers the parameter-introspection error tested below
+
+
+def test_init():
+ """Test the initialization of the MockBaseRegressor class."""
+ model = MockBaseRegressor(std_param=2)
+ assert model.std_param == 2
+
+
+def test_get_params():
+ """Test the get_params method."""
+ model = MockRegressorNested(other_param=1, std_param=2)
+ params = model.get_params(deep=True)
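+ # deep=True exposes nested estimator parameters under the "<attribute>__<param>" key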
+ assert params["std_param"] == 2
+ assert params["other_param__std_param"] == 1
+
+
+def test_set_params():
+ """Test the set_params method."""
+ model = MockBaseRegressor()
+ model.set_params(std_param=1)
+ assert model.std_param == 1
+
+
+def test_invalid_set_params():
+ """Test invalid parameter setting using the set_params method."""
+ model = MockBaseRegressor()
+ with pytest.raises(
+ ValueError, match="Invalid parameter 'invalid_param' for estimator"
+ ):
+ model.set_params(invalid_param="invalid")
+
+
+def test_get_param_names():
+ """Test retrieval of parameter names using the _get_param_names method."""
+ param_names = MockBaseRegressor._get_param_names()
+ # _get_param_names should capture the constructor arguments
+ assert "std_param" in param_names
+
+
+def test_check_invalid_entry():
+ """Test validation of data arrays."""
+ valid_data = jnp.array([1, 2, 3])
+ invalid_data_nan = jnp.array([1, 2, jnp.nan])
+ invalid_data_inf = jnp.array([1, jnp.inf, 2])
+ check_invalid_entry(valid_data, "valid_data")
+ with pytest.raises(ValueError, match="Input array 'invalid_data_nan' contains NaN"):
+ check_invalid_entry(invalid_data_nan, "invalid_data_nan")
+ with pytest.raises(ValueError, match="Input array 'invalid_data_inf' contains Inf"):
+ check_invalid_entry(invalid_data_inf, "invalid_data_inf")
+
+
+# Ensure abstract classes cannot be instantiated
+def test_abstract_class():
+ """Ensure that the abstract base class cannot be instantiated directly."""
+ with pytest.raises(TypeError, match="Can't instantiate abstract"):
+ BaseRegressor()
+
+
+def test_invalid_concrete_class():
+ """Ensure that classes missing implementation of required abstract methods raise errors."""
+ with pytest.raises(TypeError, match="Can't instantiate abstract"):
+ MockBaseRegressorInvalid()
+
+
+def test_preprocess_fit(mock_data, mock_regressor):
+ X, y = mock_data
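+ # _preprocess_fit should return the validated inputs together with default initial parameters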
+ X_out, y_out, params_out = mock_regressor._preprocess_fit(X, y)
+ assert X_out.shape == X.shape
+ assert y_out.shape == y.shape
+ assert params_out[0].shape == (2, 2) # Mock data shapes
+ assert params_out[1].shape == (2,)
+
+
+def test_preprocess_fit_empty_data(mock_regressor):
+ """Test behavior with empty data input."""
+ X, y = jnp.array([[]]), jnp.array([])
+ with pytest.raises(ValueError):
+ mock_regressor._preprocess_fit(X, y)
+
+
+def test_preprocess_fit_mismatched_shapes(mock_regressor):
+ """Test behavior with mismatched X and y shapes."""
+ X = jnp.array([[1, 2], [3, 4]])
+ y = jnp.array([1, 2, 3])
+ with pytest.raises(ValueError):
+ mock_regressor._preprocess_fit(X, y)
+
+
+def test_preprocess_fit_invalid_datatypes(mock_regressor):
+ """Test behavior with invalid data types."""
+ X = "invalid_data_type"
+ y = "invalid_data_type"
+ with pytest.raises(ValueError):
+ mock_regressor._preprocess_fit(X, y)
+
+
+def test_preprocess_fit_with_nan_in_X(mock_regressor):
+ """Test behavior with NaN values in data."""
+ X = jnp.array([[[1, 2], [jnp.nan, 4]]])
+ y = jnp.array([[1, 2]])
+ with pytest.raises(ValueError, match="Input array .+ contains"):
+ mock_regressor._preprocess_fit(X, y)
+
+
+def test_preprocess_fit_with_inf_in_X(mock_regressor):
+ """Test behavior with inf values in data."""
+ X = jnp.array([[[1, 2], [jnp.inf, 4]]])
+ y = jnp.array([[1, 2]])
+ with pytest.raises(ValueError, match="Input array .+ contains"):
+ mock_regressor._preprocess_fit(X, y)
+
+
+def test_preprocess_fit_with_nan_in_y(mock_regressor):
+ """Test behavior with NaN values in data."""
+ X = jnp.array([[[1, 2], [2, 4]]])
+ y = jnp.array([[1, jnp.nan]])
+ with pytest.raises(ValueError, match="Input array .+ contains"):
+ mock_regressor._preprocess_fit(X, y)
+
+
+def test_preprocess_fit_with_inf_in_y(mock_regressor):
+ """Test behavior with inf values in data."""
+ X = jnp.array([[[1, 2], [2, 4]]])
+ y = jnp.array([[1, jnp.inf]])
+ with pytest.raises(ValueError, match="Input array .+ contains"):
+ mock_regressor._preprocess_fit(X, y)
+
+
+def test_preprocess_fit_higher_dimensional_data_X(mock_regressor):
+ """Test behavior with higher-dimensional input data."""
+ X = jnp.array([[[[1, 2], [3, 4]]]])
+ y = jnp.array([[1, 2]])
+ with pytest.raises(ValueError, match="X must be three-dimensional"):
+ mock_regressor._preprocess_fit(X, y)
+
+
+def test_preprocess_fit_higher_dimensional_data_y(mock_regressor):
+ """Test behavior with higher-dimensional input data."""
+ X = jnp.array([[[[1, 2], [3, 4]]]])
+ y = jnp.array([[[1, 2]]])
+ with pytest.raises(ValueError, match="y must be two-dimensional"):
+ mock_regressor._preprocess_fit(X, y)
+
+
+def test_preprocess_fit_lower_dimensional_data_X(mock_regressor):
+ """Test behavior with lower-dimensional input data."""
+ X = jnp.array([[1, 2], [3, 4]])
+ y = jnp.array([[1, 2]])
+ with pytest.raises(ValueError, match="X must be three-dimensional"):
+ mock_regressor._preprocess_fit(X, y)
+
+
+def test_preprocess_fit_lower_dimensional_data_y(mock_regressor):
+ """Test behavior with lower-dimensional input data."""
+ X = jnp.array([[[[1, 2], [3, 4]]]])
+ y = jnp.array([1, 2])
+ with pytest.raises(ValueError, match="y must be two-dimensional"):
+ mock_regressor._preprocess_fit(X, y)
+
+
+# Preprocess Simulate Tests
+def test_preprocess_simulate_empty_data(mock_regressor):
+ """Test behavior with empty feedforward_input."""
+ feedforward_input = jnp.array([[[]]])
+ params_f = (jnp.array([[]]), jnp.array([]))
+ with pytest.raises(ValueError, match="Model parameters have inconsistent shapes."):
+ mock_regressor._preprocess_simulate(feedforward_input, params_f)
+
+
+def test_preprocess_simulate_invalid_datatypes(mock_regressor):
+ """Test behavior with invalid feedforward_input datatype."""
+ feedforward_input = "invalid_data_type"
+ params_f = (jnp.array([[]]),)
+ with pytest.raises(
+ ValueError,
+ match="could not convert string",
+ ):
+ mock_regressor._preprocess_simulate(feedforward_input, params_f)
+
+
+def test_preprocess_simulate_with_nan(mock_regressor):
+ """Test behavior with NaN values in feedforward_input."""
+ feedforward_input = jnp.array([[[jnp.nan]]])
+ params_f = (jnp.array([[1]]), jnp.array([1]))
+ with pytest.raises(ValueError, match="Input array .+ contains"):
+ mock_regressor._preprocess_simulate(feedforward_input, params_f)
+
+
+def test_preprocess_simulate_with_inf(mock_regressor):
+ """Test behavior with infinite values in feedforward_input."""
+ feedforward_input = jnp.array([[[jnp.inf]]])
+ params_f = (jnp.array([[1]]), jnp.array([1]))
+ with pytest.raises(ValueError, match="Input array .+ contains"):
+ mock_regressor._preprocess_simulate(feedforward_input, params_f)
+
+
+def test_preprocess_simulate_higher_dimensional_data(mock_regressor):
+ """Test behavior with improperly dimensional feedforward_input."""
+ feedforward_input = jnp.array([[[[1]]]])
+ params_f = (jnp.array([[1]]), jnp.array([1]))
+ with pytest.raises(ValueError, match="X must be three-dimensional"):
+ mock_regressor._preprocess_simulate(feedforward_input, params_f)
+
+
+def test_preprocess_simulate_invalid_init_y(mock_regressor):
+ """Test behavior with invalid init_y provided."""
+ feedforward_input = jnp.array([[[1]]])
+ params_f = (jnp.array([[1]]), jnp.array([1]))
+ init_y = jnp.array([[[1]]])
+ params_r = (jnp.array([[1]]),)
+ with pytest.raises(ValueError, match="y must be two-dimensional"):
+ mock_regressor._preprocess_simulate(
+ feedforward_input, params_f, init_y, params_r
+ )
+
+
+def test_preprocess_simulate_feedforward(mock_regressor):
+ """Test that the preprocessing works."""
+ feedforward_input = jnp.array([[[1]]])
+ params_f = (jnp.array([[1]]), jnp.array([1]))
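+ # with no init_y provided, _preprocess_simulate returns a one-element tuple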
+ (ff,) = mock_regressor._preprocess_simulate(feedforward_input, params_f)
+ assert jnp.all(ff == feedforward_input)
+
+
+def test_empty_set(mock_regressor):
+ """Check that an empty set_params returns self."""
+ assert mock_regressor.set_params() is mock_regressor
+
+
+def test_glm_varargs_error():
+ """Test that variable number of argument in __init__ is not allowed."""
+ bad_estimator = BadEstimator(1)
+ with pytest.raises(
+ RuntimeError, match="GLM estimators should always specify their parameters"
+ ):
+ bad_estimator._get_param_names()
diff --git a/tests/test_basis.py b/tests/test_basis.py
index 2d7e83f3..59f3b2f1 100644
--- a/tests/test_basis.py
+++ b/tests/test_basis.py
@@ -281,7 +281,7 @@ def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs):
with pytest.raises(
ValueError,
match=f"Object class {self.cls.__name__} "
- "requires >= 1 basis elements\.",
+ r"requires >= 1 basis elements\.",
):
self.cls(n_basis_funcs=n_basis_funcs)
else:
@@ -538,7 +538,7 @@ def test_non_empty_samples(self, samples):
self.cls(5, decay_rates=np.arange(1, 6)).evaluate(samples)
@pytest.mark.parametrize(
- "eval_input", [0, [0]*6, (0,)*6, np.array([0]*6), jax.numpy.array([0]*6)]
+ "eval_input", [0, [0] * 6, (0,) * 6, np.array([0] * 6), jax.numpy.array([0] * 6)]
)
def test_evaluate_input(self, eval_input):
"""
diff --git a/tests/test_convolution_1d.py b/tests/test_convolution_1d.py
index 4cfa259d..09959f94 100644
--- a/tests/test_convolution_1d.py
+++ b/tests/test_convolution_1d.py
@@ -13,7 +13,7 @@ def test_basis_matrix_type(self, basis_matrix, trial_count_shape: tuple[int]):
if raise_exception:
with pytest.raises(
ValueError,
- match="Empty basis_matrix provided\. "
+ match=r"Empty basis_matrix provided\. "
r"The shape of basis_matrix is \(0, 0\)!",
):
utils.convolve_1d_trials(basis_matrix, vec)
diff --git a/tests/test_glm.py b/tests/test_glm.py
new file mode 100644
index 00000000..011a2b56
--- /dev/null
+++ b/tests/test_glm.py
@@ -0,0 +1,1116 @@
+from contextlib import nullcontext as does_not_raise
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+import statsmodels.api as sm
+from sklearn.model_selection import GridSearchCV
+
+import nemos as nmo
+
+
+class TestGLM:
+ """
+ Unit tests for the PoissonGLM class.
+ """
+
+ #######################
+ # Test model.__init__
+ #######################
+ @pytest.mark.parametrize(
+ "regularizer, expectation",
+ [
+ (nmo.regularizer.Ridge("BFGS"), does_not_raise()),
+ (None, pytest.raises(AttributeError, match="The provided `solver` doesn't implement ")),
+ (nmo.regularizer.Ridge, pytest.raises(TypeError, match="The provided `solver` cannot be instantiated")),
+ ],
+ )
+ def test_solver_type(self, regularizer, expectation, glm_class):
+ """
+ Test that an error is raised if a non-compatible solver is passed.
+ """
+ with expectation:
+ glm_class(regularizer=regularizer)
+
+ @pytest.mark.parametrize(
+ "observation, expectation",
+ [
+ (nmo.observation_models.PoissonObservations(), does_not_raise()),
+ (nmo.regularizer.Regularizer, pytest.raises(AttributeError, match="The provided object does not have the required")),
+ (1, pytest.raises(AttributeError, match="The provided object does not have the required")),
+ ],
+ )
+ def test_init_observation_type(
+ self, observation, expectation, glm_class, ridge_regularizer
+ ):
+ """
+ Test initialization with different observation model types. Check that an appropriate
+ exception is raised when the provided object lacks the required observation-model attributes.
+ """
+ with expectation:
+ glm_class(regularizer=ridge_regularizer, observation_model=observation)
+
+ #######################
+ # Test model.fit
+ #######################
+ @pytest.mark.parametrize(
+ "n_params, expectation",
+ [
+ (0, pytest.raises(ValueError, match="Params needs to be array-like of length two.")),
+ (1, pytest.raises(ValueError, match="Params needs to be array-like of length two.")),
+ (2, does_not_raise()),
+ (3, pytest.raises(ValueError, match="Params needs to be array-like of length two.")),
+ ],
+ )
+ def test_fit_param_length(
+ self, n_params, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `fit` method with different numbers of initial parameters.
+ Check for correct number of parameters.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
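+ # build an init_params tuple of the requested length; only a length-two tuple is valid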
+ if n_params == 0:
+ init_params = tuple()
+ elif n_params == 1:
+ init_params = (true_params[0],)
+ else:
+ init_params = true_params + (true_params[0],) * (n_params - 2)
+ with expectation:
+ model.fit(X, y, init_params=init_params)
+
+ @pytest.mark.parametrize(
+ "add_entry, add_to, expectation",
+ [
+ (0, "X", does_not_raise()),
+ (np.nan, "X", pytest.raises(ValueError, match="Input array .+ contains")),
+ (np.inf, "X", pytest.raises(ValueError, match="Input array .+ contains")),
+ (0, "y", does_not_raise()),
+ (np.nan, "y", pytest.raises(ValueError, match="Input array .+ contains")),
+ (np.inf, "y", pytest.raises(ValueError, match="Input array .+ contains")),
+ ],
+ )
+ def test_fit_param_values(
+ self, add_entry, add_to, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `fit` method with altered X or y values. Ensure the method raises exceptions for NaN or Inf values.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ if add_to == "X":
+ # get an index to be edited
+ idx = np.unravel_index(np.random.choice(X.size), X.shape)
+ X[idx] = add_entry
+ elif add_to == "y":
+ idx = np.unravel_index(np.random.choice(y.size), y.shape)
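+ # cast y to float so that NaN/Inf entries can be assigned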
+ y = np.asarray(y, dtype=np.float32)
+ y[idx] = add_entry
+ with expectation:
+ model.fit(X, y, init_params=true_params)
+
+ @pytest.mark.parametrize(
+ "dim_weights, expectation",
+ [
+ (0, pytest.raises(ValueError, match=r"params\[0\] must be of shape \(n_neurons, n_features\)")),
+ (1, pytest.raises(ValueError, match=r"params\[0\] must be of shape \(n_neurons, n_features\)")),
+ (2, does_not_raise()),
+ (3, pytest.raises(ValueError, match=r"params\[0\] must be of shape \(n_neurons, n_features\)")),
+ ],
+ )
+ def test_fit_weights_dimensionality(
+ self, dim_weights, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `fit` method with weight matrices of different dimensionalities.
+ Check for correct dimensionality.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ n_samples, n_neurons, n_features = X.shape
+ if dim_weights == 0:
+ init_w = jnp.array([])
+ elif dim_weights == 1:
+ init_w = jnp.zeros((n_neurons,))
+ elif dim_weights == 2:
+ init_w = jnp.zeros((n_neurons, n_features))
+ else:
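+ # pad the valid (n_neurons, n_features) shape with singleton axes to exceed two dimensions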
+ init_w = jnp.zeros((n_neurons, n_features) + (1,) * (dim_weights - 2))
+ with expectation:
+ model.fit(X, y, init_params=(init_w, true_params[1]))
+
+ @pytest.mark.parametrize(
+ "dim_intercepts, expectation",
+ [
+ (0, pytest.raises(ValueError, match=r"params\[1\] must be of shape")),
+ (1, does_not_raise()),
+ (2, pytest.raises(ValueError, match=r"params\[1\] must be of shape")),
+ (3, pytest.raises(ValueError, match=r"params\[1\] must be of shape")),
+ ],
+ )
+ def test_fit_intercepts_dimensionality(
+ self, dim_intercepts, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `fit` method with intercepts of different dimensionalities. Check for correct dimensionality.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ n_samples, n_neurons, n_features = X.shape
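+ # (n_neurons,) * dim_intercepts builds a shape tuple of the requested dimensionality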
+ init_b = jnp.zeros((n_neurons,) * dim_intercepts)
+ init_w = jnp.zeros((n_neurons, n_features))
+ with expectation:
+ model.fit(X, y, init_params=(init_w, init_b))
+
+ @pytest.mark.parametrize(
+ "init_params, expectation",
+ [
+ ([jnp.zeros((1, 5)), jnp.zeros((1,))], does_not_raise()),
+ (iter([jnp.zeros((1, 5)), jnp.zeros((1,))]), does_not_raise()),
+ (dict(p1=jnp.zeros((1, 5)), p2=jnp.zeros((1,))), pytest.raises(TypeError, match="Initial parameters must be array-like")),
+ (0, pytest.raises(TypeError, match="Initial parameters must be array-like")),
+ ({0, 1}, pytest.raises(ValueError, match=r"params\[0\] must be of shape")),
+ ([jnp.zeros((1, 5)), ""], pytest.raises(TypeError, match="Initial parameters must be array-like")),
+ (["", jnp.zeros((1,))], pytest.raises(TypeError, match="Initial parameters must be array-like")),
+ ],
+ )
+ def test_fit_init_params_type(
+ self, init_params, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `fit` method with various types of initial parameters. Ensure that the provided initial parameters
+ are array-like.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ with expectation:
+ model.fit(X, y, init_params=init_params)
+
+ @pytest.mark.parametrize(
+ "delta_n_neuron, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="Model parameters have inconsistent shapes")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="Model parameters have inconsistent shapes")),
+ ],
+ )
+ def test_fit_n_neuron_match_baseline_rate(
+ self, delta_n_neuron, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `fit` method, ensuring that the number of neurons in the baseline rate matches the expected number.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ n_samples, n_neurons, n_features = X.shape
+ init_b = jnp.zeros((n_neurons + delta_n_neuron,))
+ with expectation:
+ model.fit(X, y, init_params=(true_params[0], init_b))
+
+ @pytest.mark.parametrize(
+ "delta_n_neuron, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="The number of neurons in the model parameters")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="The number of neurons in the model parameters")),
+ ],
+ )
+ def test_fit_n_neuron_match_x(
+ self, delta_n_neuron, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `fit` method, ensuring that the number of neurons in X matches the number of neurons in the model.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ n_neurons = X.shape[1]
+ X = jnp.repeat(X, n_neurons + delta_n_neuron, axis=1)
+ with expectation:
+ model.fit(X, y, init_params=true_params)
+
+ @pytest.mark.parametrize(
+ "delta_n_neuron, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="The number of neurons in the model parameters")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="The number of neurons in the model parameters")),
+ ],
+ )
+ def test_fit_n_neuron_match_y(
+ self, delta_n_neuron, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `fit` method, ensuring that the number of neurons in y matches the number of neurons in the model.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ n_neurons = X.shape[1]
+ y = jnp.repeat(y, n_neurons + delta_n_neuron, axis=1)
+ with expectation:
+ model.fit(X, y, init_params=true_params)
+
+ @pytest.mark.parametrize(
+ "delta_dim, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="X must be three-dimensional")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="X must be three-dimensional")),
+ ],
+ )
+ def test_fit_x_dimensionality(
+ self, delta_dim, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `fit` method with X input data of different dimensionalities. Ensure correct dimensionality for X.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ if delta_dim == -1:
+ X = np.zeros((X.shape[0], X.shape[1]))
+ elif delta_dim == 1:
+ X = np.zeros((X.shape[0], X.shape[1], X.shape[2], 1))
+ with expectation:
+ model.fit(X, y, init_params=true_params)
+
+ @pytest.mark.parametrize(
+ "delta_dim, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="y must be two-dimensional")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="y must be two-dimensional")),
+ ],
+ )
+ def test_fit_y_dimensionality(
+ self, delta_dim, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `fit` method with y target data of different dimensionalities. Ensure correct dimensionality for y.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ if delta_dim == -1:
+ y = np.zeros(X.shape[0])
+ elif delta_dim == 1:
+ y = np.zeros((X.shape[0], X.shape[1], 1))
+ with expectation:
+ model.fit(X, y, init_params=true_params)
+
+ @pytest.mark.parametrize(
+ "delta_n_features, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="Inconsistent number of features")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="Inconsistent number of features")),
+ ],
+ )
+ def test_fit_n_feature_consistency_weights(
+ self, delta_n_features, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `fit` method for inconsistencies between data features and initial weights provided.
+ Ensure the number of features align.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ init_w = jnp.zeros((X.shape[1], X.shape[2] + delta_n_features))
+ init_b = jnp.zeros(X.shape[1])
+ with expectation:
+ model.fit(X, y, init_params=(init_w, init_b))
+
+ @pytest.mark.parametrize(
+ "delta_n_features, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="Inconsistent number of features")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="Inconsistent number of features")),
+ ],
+ )
+ def test_fit_n_feature_consistency_x(
+ self, delta_n_features, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `fit` method for inconsistencies between data features and model's expectations.
+ Ensure the number of features in X aligns.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ if delta_n_features == 1:
+ X = jnp.concatenate((X, jnp.zeros((X.shape[0], X.shape[1], 1))), axis=2)
+ elif delta_n_features == -1:
+ X = X[..., :-1]
+ with expectation:
+ model.fit(X, y, init_params=true_params)
+
+ @pytest.mark.parametrize(
+ "delta_tp, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="The number of time-points in X and y")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="The number of time-points in X and y")),
+ ],
+ )
+ def test_fit_time_points_x(
+ self, delta_tp, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `fit` method for inconsistencies in time-points in data X. Ensure the correct number of time-points.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
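+ # rebuild X with delta_tp extra (or missing) time points, keeping the other dimensions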
+ X = jnp.zeros((X.shape[0] + delta_tp,) + X.shape[1:])
+ with expectation:
+ model.fit(X, y, init_params=true_params)
+
+ @pytest.mark.parametrize(
+ "delta_tp, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="The number of time-points in X and y")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="The number of time-points in X and y")),
+ ],
+ )
+ def test_fit_time_points_y(
+ self, delta_tp, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `fit` method for inconsistencies in time-points in y. Ensure the correct number of time-points.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ y = jnp.zeros((y.shape[0] + delta_tp,) + y.shape[1:])
+ with expectation:
+ model.fit(X, y, init_params=true_params)
+
+ def test_fit_mask_grouplasso(self, group_sparse_poisson_glm_model_instantiation):
+ """Test that the group lasso fit goes through"""
+ X, y, model, params, rate, mask = group_sparse_poisson_glm_model_instantiation
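+ # the fixture-provided mask encodes the feature group structure for the GroupLasso penalty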
+ model.set_params(
+ regularizer=nmo.regularizer.GroupLasso(
+ solver_name="ProximalGradient", mask=mask
+ )
+ )
+ model.fit(X, y)
+
+ #######################
+ # Test model.score
+ #######################
+ @pytest.mark.parametrize(
+ "delta_n_neuron, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="The number of neurons in the model parameters")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="The number of neurons in the model parameters")),
+ ],
+ )
+ def test_score_n_neuron_match_x(
+ self, delta_n_neuron, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `score` method when the number of neurons in X differs. Ensure the correct number of neurons.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ model.coef_ = true_params[0]
+ model.intercept_ = true_params[1]
+ X = jnp.repeat(X, X.shape[1] + delta_n_neuron, axis=1)
+ with expectation:
+ model.score(X, y)
+
+ @pytest.mark.parametrize(
+ "delta_n_neuron, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="The number of neurons in the model parameters")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="The number of neurons in the model parameters")),
+ ],
+ )
+ def test_score_n_neuron_match_y(
+ self, delta_n_neuron, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `score` method when the number of neurons in y differs. Ensure the correct number of neurons.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ model.coef_ = true_params[0]
+ model.intercept_ = true_params[1]
+ y = jnp.repeat(y, y.shape[1] + delta_n_neuron, axis=1)
+ with expectation:
+ model.score(X, y)
+
+ @pytest.mark.parametrize(
+ "delta_dim, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="X must be three-dimensional")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="X must be three-dimensional")),
+ ],
+ )
+ def test_score_x_dimensionality(
+ self, delta_dim, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `score` method with X input data of different dimensionalities. Ensure correct dimensionality for X.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ model.coef_ = true_params[0]
+ model.intercept_ = true_params[1]
+ if delta_dim == -1:
+ X = np.zeros((X.shape[0], X.shape[1]))
+ elif delta_dim == 1:
+ X = np.zeros((X.shape[0], X.shape[1], X.shape[2], 1))
+ with expectation:
+ model.score(X, y)
+
+ @pytest.mark.parametrize(
+ "delta_dim, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="y must be two-dimensional, with shape")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="y must be two-dimensional, with shape")),
+ ],
+ )
+ def test_score_y_dimensionality(
+ self, delta_dim, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `score` method with y of different dimensionalities.
+ Ensure correct dimensionality for y.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ model.coef_ = true_params[0]
+ model.intercept_ = true_params[1]
+ if delta_dim == -1:
+ y = np.zeros((X.shape[0],))
+ elif delta_dim == 1:
+ y = np.zeros((X.shape[0], X.shape[1], 1))
+ with expectation:
+ model.score(X, y)
+
+ @pytest.mark.parametrize(
+ "delta_n_features, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="Inconsistent number of features")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="Inconsistent number of features")),
+ ],
+ )
+ def test_score_n_feature_consistency_x(
+ self, delta_n_features, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `score` method for inconsistencies in features of X.
+ Ensure the number of features in X aligns with the model params.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ model.coef_ = true_params[0]
+ model.intercept_ = true_params[1]
+ if delta_n_features == 1:
+ X = jnp.concatenate((X, jnp.zeros((X.shape[0], X.shape[1], 1))), axis=2)
+ elif delta_n_features == -1:
+ X = X[..., :-1]
+ with expectation:
+ model.score(X, y)
+
+ @pytest.mark.parametrize(
+ "is_fit, expectation",
+ [
+ (True, does_not_raise()),
+ (False, pytest.raises(ValueError, match="This GLM instance is not fitted yet")),
+ ],
+ )
+ def test_predict_is_fit(self, is_fit, expectation, poissonGLM_model_instantiation):
+ """
+ Test the `predict` method on models based on their fit status.
+ Ensure prediction is only possible on fitted models.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ if is_fit:
+ model.fit(X, y)
+ with expectation:
+ model.predict(X)
+
+ @pytest.mark.parametrize(
+ "delta_tp, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="The number of time-points in X and y")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="The number of time-points in X and y")),
+ ],
+ )
+ def test_score_time_points_x(
+ self, delta_tp, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `score` method for inconsistencies in time-points in X.
+ Ensure that the number of time-points in X and y matches.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ model.coef_ = true_params[0]
+ model.intercept_ = true_params[1]
+ X = jnp.zeros((X.shape[0] + delta_tp,) + X.shape[1:])
+ with expectation:
+ model.score(X, y)
+
+ @pytest.mark.parametrize(
+ "delta_tp, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="The number of time-points in X and y")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="The number of time-points in X and y")),
+ ],
+ )
+ def test_score_time_points_y(
+ self, delta_tp, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `score` method for inconsistencies in time-points in y.
+ Ensure that the number of time-points in X and y matches.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ model.coef_ = true_params[0]
+ model.intercept_ = true_params[1]
+ y = jnp.zeros((y.shape[0] + delta_tp,) + y.shape[1:])
+ with expectation:
+ model.score(X, y)
+
+ @pytest.mark.parametrize(
+ "score_type, expectation",
+ [
+ ("pseudo-r2-McFadden", does_not_raise()),
+ ("pseudo-r2-Cohen", does_not_raise()),
+ ("log-likelihood", does_not_raise()),
+ ("not-implemented", pytest.raises(NotImplementedError, match="Scoring method not-implemented not implemented")),
+ ],
+ )
+ def test_score_type_r2(
+ self, score_type, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `score` method for unsupported scoring types.
+ Ensure only valid score types are used.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ model.coef_ = true_params[0]
+ model.intercept_ = true_params[1]
+ with expectation:
+ model.score(X, y, score_type=score_type)
+
+ def test_loglikelihood_against_scipy_stats(self, poissonGLM_model_instantiation):
+ """
+ Compare the model's log-likelihood computation against `jax.scipy`.
+ Ensure consistent and correct calculations.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ # set model coeff
+ model.coef_ = true_params[0]
+ model.intercept_ = true_params[1]
+ # get the rate
+ mean_firing = model.predict(X)
+ # compute the log-likelihood using jax.scipy
+ mean_ll_jax = jax.scipy.stats.poisson.logpmf(y, mean_firing).mean()
+ model_ll = model.score(X, y, score_type="log-likelihood")
+ if not np.allclose(mean_ll_jax, model_ll):
+ raise ValueError(
+ "Log-likelihood of PoissonModel does not match" "that of jax.scipy!"
+ )
+
+ #######################
+ # Test model.predict
+ #######################
+ @pytest.mark.parametrize(
+ "delta_n_neuron, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="The number of neurons in the model parameters")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="The number of neurons in the model parameters")),
+ ],
+ )
+ def test_predict_n_neuron_match_x(
+ self, delta_n_neuron, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `predict` method when the number of neurons in X differs.
+ Ensure that the number of neurons in X, y, and the model parameters match.
+ """
+ X, _, model, true_params, _ = poissonGLM_model_instantiation
+ model.coef_ = true_params[0]
+ model.intercept_ = true_params[1]
+ X = jnp.repeat(X, X.shape[1] + delta_n_neuron, axis=1)
+ with expectation:
+ model.predict(X)
+
+ @pytest.mark.parametrize(
+ "delta_dim, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="X must be three-dimensional")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="X must be three-dimensional")),
+ ],
+ )
+ def test_predict_x_dimensionality(
+ self, delta_dim, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `predict` method with X input data of different dimensionalities.
+ Ensure correct dimensionality for X.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ model.coef_ = true_params[0]
+ model.intercept_ = true_params[1]
+ if delta_dim == -1:
+ X = np.zeros((X.shape[0], X.shape[1]))
+ elif delta_dim == 1:
+ X = np.zeros((X.shape[0], X.shape[1], X.shape[2], 1))
+ with expectation:
+ model.predict(X)
+
+ @pytest.mark.parametrize(
+ "delta_n_features, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="Inconsistent number of features")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="Inconsistent number of features")),
+ ],
+ )
+ def test_predict_n_feature_consistency_x(
+ self, delta_n_features, expectation, poissonGLM_model_instantiation
+ ):
+ """
+ Test the `predict` method ensuring the number of features in the input data X
+ is consistent with the model's `model.coef_`.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ model.coef_ = true_params[0]
+ model.intercept_ = true_params[1]
+ if delta_n_features == 1:
+ X = jnp.concatenate((X, jnp.zeros((X.shape[0], X.shape[1], 1))), axis=2)
+ elif delta_n_features == -1:
+ X = X[..., :-1]
+ with expectation:
+ model.predict(X)
+
+ @pytest.mark.parametrize(
+ "is_fit, expectation",
+ [
+ (True, does_not_raise()),
+ (False, pytest.raises(ValueError, match="This GLM instance is not fitted yet")),
+ ],
+ )
+ def test_predict_is_fit(self, is_fit, expectation, poissonGLM_model_instantiation):
+ """
+ Test the `predict` method on models based on their fit status.
+ Ensure prediction is only possible on fitted models.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ if is_fit:
+ model.fit(X, y)
+ with expectation:
+ model.predict(X)
+
+ #######################
+ # Test model.simulate
+ #######################
+ @pytest.mark.parametrize(
+ "delta_n_neuron, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="The number of neurons in the model parameters")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="The number of neurons in the model parameters")),
+ ],
+ )
+ def test_simulate_n_neuron_match_input(
+ self, delta_n_neuron, expectation, poissonGLM_coupled_model_config_simulate
+ ):
+ """
+ Test the `simulate` method to ensure that the number of neurons in the input
+ matches the model's parameters.
+ """
+ (
+ model,
+ coupling_basis,
+ feedforward_input,
+ init_spikes,
+ random_key,
+ ) = poissonGLM_coupled_model_config_simulate
+ if delta_n_neuron != 0:
+ feedforward_input = np.zeros(
+ (
+ feedforward_input.shape[0],
+ feedforward_input.shape[1] + delta_n_neuron,
+ feedforward_input.shape[2],
+ )
+ )
+ with expectation:
+ model.simulate_recurrent(
+ random_key=random_key,
+ init_y=init_spikes,
+ coupling_basis_matrix=coupling_basis,
+ feedforward_input=feedforward_input,
+ )
+
+ @pytest.mark.parametrize(
+ "delta_dim, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="X must be three-dimensional")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="X must be three-dimensional")),
+ ],
+ )
+ def test_simulate_input_dimensionality(
+ self, delta_dim, expectation, poissonGLM_coupled_model_config_simulate
+ ):
+ """
+ Test the `simulate` method with input data of different dimensionalities.
+ Ensure correct dimensionality for input.
+ """
+ (
+ model,
+ coupling_basis,
+ feedforward_input,
+ init_spikes,
+ random_key,
+ ) = poissonGLM_coupled_model_config_simulate
+ if delta_dim == -1:
+ feedforward_input = np.zeros(feedforward_input.shape[:2])
+ elif delta_dim == 1:
+ feedforward_input = np.zeros(feedforward_input.shape + (1,))
+ with expectation:
+ model.simulate_recurrent(
+ random_key=random_key,
+ init_y=init_spikes,
+ coupling_basis_matrix=coupling_basis,
+ feedforward_input=feedforward_input,
+ )
+
+ @pytest.mark.parametrize(
+ "delta_dim, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="y must be two-dimensional")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="y must be two-dimensional")),
+ ],
+ )
+ def test_simulate_y_dimensionality(
+ self, delta_dim, expectation, poissonGLM_coupled_model_config_simulate
+ ):
+ """
+ Test the `simulate` method with init_spikes of different dimensionalities.
+ Ensure correct dimensionality for init_spikes.
+ """
+ (
+ model,
+ coupling_basis,
+ feedforward_input,
+ init_spikes,
+ random_key,
+ ) = poissonGLM_coupled_model_config_simulate
+ if delta_dim == -1:
+ init_spikes = np.zeros((feedforward_input.shape[0],))
+ elif delta_dim == 1:
+ init_spikes = np.zeros(
+ (feedforward_input.shape[0], feedforward_input.shape[1], 1)
+ )
+ with expectation:
+ model.simulate_recurrent(
+ random_key=random_key,
+ init_y=init_spikes,
+ coupling_basis_matrix=coupling_basis,
+ feedforward_input=feedforward_input,
+ )
+
+ @pytest.mark.parametrize(
+ "delta_n_neuron, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="The number of neurons in the model parameters")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="The number of neurons in the model parameters")),
+ ],
+ )
+ def test_simulate_n_neuron_match_y(
+ self, delta_n_neuron, expectation, poissonGLM_coupled_model_config_simulate
+ ):
+ """
+ Test the `simulate` method to ensure that the number of neurons in init_spikes
+ matches the model's parameters.
+ """
+ (
+ model,
+ coupling_basis,
+ feedforward_input,
+ init_spikes,
+ random_key,
+ ) = poissonGLM_coupled_model_config_simulate
+ init_spikes = jnp.zeros(
+ (init_spikes.shape[0], feedforward_input.shape[1] + delta_n_neuron)
+ )
+ with expectation:
+ model.simulate_recurrent(
+ random_key=random_key,
+ init_y=init_spikes,
+ coupling_basis_matrix=coupling_basis,
+ feedforward_input=feedforward_input,
+ )
+
+ @pytest.mark.parametrize(
+ "is_fit, expectation",
+ [
+ (True, does_not_raise()),
+ (False, pytest.raises(ValueError, match="This GLM instance is not fitted yet")),
+ ],
+ )
+ def test_simulate_is_fit(
+ self, is_fit, expectation, poissonGLM_coupled_model_config_simulate
+ ):
+ """
+ Test if the model raises a ValueError when trying to simulate before it's fitted.
+ """
+ (
+ model,
+ coupling_basis,
+ feedforward_input,
+ init_spikes,
+ random_key,
+ ) = poissonGLM_coupled_model_config_simulate
+ if not is_fit:
+ model.intercept_ = None
+ with expectation:
+ model.simulate_recurrent(
+ random_key=random_key,
+ init_y=init_spikes,
+ coupling_basis_matrix=coupling_basis,
+ feedforward_input=feedforward_input,
+ )
+
+ @pytest.mark.parametrize(
+ "delta_tp, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="`init_y` and `coupling_basis_matrix`")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="`init_y` and `coupling_basis_matrix`")),
+ ],
+ )
+ def test_simulate_time_point_match_y(
+ self, delta_tp, expectation, poissonGLM_coupled_model_config_simulate
+ ):
+ """
+ Test the `simulate` method to ensure that the time points in init_y
+ are consistent with the coupling_basis window size (they must be equal).
+ """
+ (
+ model,
+ coupling_basis,
+ feedforward_input,
+ init_spikes,
+ random_key,
+ ) = poissonGLM_coupled_model_config_simulate
+ init_spikes = jnp.zeros((init_spikes.shape[0] + delta_tp, init_spikes.shape[1]))
+ with expectation:
+ model.simulate_recurrent(
+ random_key=random_key,
+ init_y=init_spikes,
+ coupling_basis_matrix=coupling_basis,
+ feedforward_input=feedforward_input,
+ )
+
+ @pytest.mark.parametrize(
+ "delta_tp, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="`init_y` and `coupling_basis_matrix`")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="`init_y` and `coupling_basis_matrix`")),
+ ],
+ )
+ def test_simulate_time_point_match_coupling_basis(
+ self, delta_tp, expectation, poissonGLM_coupled_model_config_simulate
+ ):
+ """
+ Test the `simulate` method to ensure that the window size in coupling_basis
+ is consistent with the time-points in init_spikes (they must be equal).
+ """
+ (
+ model,
+ coupling_basis,
+ feedforward_input,
+ init_spikes,
+ random_key,
+ ) = poissonGLM_coupled_model_config_simulate
+ coupling_basis = jnp.zeros(
+ (coupling_basis.shape[0] + delta_tp,) + coupling_basis.shape[1:]
+ )
+ with expectation:
+ model.simulate_recurrent(
+ random_key=random_key,
+ init_y=init_spikes,
+ coupling_basis_matrix=coupling_basis,
+ feedforward_input=feedforward_input,
+ )
+
+ @pytest.mark.parametrize(
+ "delta_features, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="Inconsistent number of features. spike basis coefficients has")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="Inconsistent number of features. spike basis coefficients has")),
+ ],
+ )
+ def test_simulate_feature_consistency_input(
+ self, delta_features, expectation, poissonGLM_coupled_model_config_simulate
+ ):
+ """
+ Test the `simulate` method ensuring the number of features in `feedforward_input` is
+ consistent with the model's expected number of features.
+
+ Notes
+ -----
+ The total feature number `model.coef_.shape[1]` must be equal to
+ `feedforward_input.shape[2] + coupling_basis.shape[1]*n_neurons`
+ """
+ (
+ model,
+ coupling_basis,
+ feedforward_input,
+ init_spikes,
+ random_key,
+ ) = poissonGLM_coupled_model_config_simulate
+ feedforward_input = jnp.zeros(
+ (
+ feedforward_input.shape[0],
+ feedforward_input.shape[1],
+ feedforward_input.shape[2] + delta_features,
+ )
+ )
+ with expectation:
+ model.simulate_recurrent(
+ random_key=random_key,
+ init_y=init_spikes,
+ coupling_basis_matrix=coupling_basis,
+ feedforward_input=feedforward_input,
+ )
+
+ @pytest.mark.parametrize(
+ "delta_features, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="Inconsistent number of features")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match="Inconsistent number of features")),
+ ],
+ )
+ def test_simulate_feature_consistency_coupling_basis(
+ self, delta_features, expectation, poissonGLM_coupled_model_config_simulate
+ ):
+ """
+ Test the `simulate` method ensuring the number of features in `coupling_basis` is
+ consistent with the model's expected number of features.
+
+ Notes
+ -----
+ The total feature number `model.coef_.shape[1]` must be equal to
+ `feedforward_input.shape[2] + coupling_basis.shape[1]*n_neurons`
+ """
+ (
+ model,
+ coupling_basis,
+ feedforward_input,
+ init_spikes,
+ random_key,
+ ) = poissonGLM_coupled_model_config_simulate
+ coupling_basis = jnp.zeros(
+ (coupling_basis.shape[0], coupling_basis.shape[1] + delta_features)
+ )
+ with expectation:
+ model.simulate_recurrent(
+ random_key=random_key,
+ init_y=init_spikes,
+ coupling_basis_matrix=coupling_basis,
+ feedforward_input=feedforward_input,
+ )
+
+ def test_simulate_feedforward_GLM_not_fit(self, poissonGLM_model_instantiation):
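+ """Test that the feedforward `simulate` raises a NotFittedError on an unfit model."""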
+ X, y, model, params, rate = poissonGLM_model_instantiation
+ with pytest.raises(
+ nmo.exceptions.NotFittedError, match="This GLM instance is not fitted yet"
+ ):
+ model.simulate(jax.random.PRNGKey(123), X)
+
+ def test_simulate_feedforward_GLM(self, poissonGLM_model_instantiation):
+ """Test that simulate goes through"""
+ X, y, model, params, rate = poissonGLM_model_instantiation
+ model.coef_ = params[0]
+ model.intercept_ = params[1]
+ ysim, ratesim = model.simulate(jax.random.PRNGKey(123), X)
+ # check that the expected dimensionality is returned
+ assert ysim.ndim == 2
+ assert ratesim.ndim == 2
+ # check that the rates and spikes have the same shape
+ assert ratesim.shape[0] == ysim.shape[0]
+ assert ratesim.shape[1] == ysim.shape[1]
+ # check that the number of time points is as expected (same as the input)
+ assert ysim.shape[0] == X.shape[0]
+ # check that the number of neurons is respected
+ assert ysim.shape[1] == y.shape[1]
+
+ #######################################
+ # Compare with standard implementation
+ #######################################
+ def test_deviance_against_statsmodels(self, poissonGLM_model_instantiation):
+ """
+ Compare the deviance computation to statsmodels.
+ Assesses if the model's deviance matches statsmodels' results.
+ """
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ # set model coeff
+ model.coef_ = true_params[0]
+ model.intercept_ = true_params[1]
+ # get the rate
+ dev = sm.families.Poisson().deviance(y, firing_rate)
+ dev_model = model.observation_model.deviance(firing_rate, y).sum()
+ if not np.allclose(dev, dev_model):
+ raise ValueError("Deviance doesn't match statsmodels!")
+
+ def test_compatibility_with_sklearn_cv(self, poissonGLM_model_instantiation):
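+ """Check that the GLM is compatible with sklearn's `GridSearchCV`."""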
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ param_grid = {"regularizer__solver_name": ["BFGS", "GradientDescent"]}
+ GridSearchCV(model, param_grid).fit(X, y)
+
+ def test_end_to_end_fit_and_simulate(
+ self, poissonGLM_coupled_model_config_simulate
+ ):
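+ """Simulate spikes from a coupled GLM, fit on the simulated data, then simulate again."""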
+ (
+ model,
+ coupling_basis,
+ feedforward_input,
+ init_spikes,
+ random_key,
+ ) = poissonGLM_coupled_model_config_simulate
+ window_size = coupling_basis.shape[0]
+ n_neurons = init_spikes.shape[1]
+ n_trials = 1
+ n_timepoints = feedforward_input.shape[0]
+
+ # generate spike trains
+ spikes, _ = model.simulate_recurrent(
+ random_key=random_key,
+ init_y=init_spikes,
+ coupling_basis_matrix=coupling_basis,
+ feedforward_input=feedforward_input,
+ )
+
+ # convolve basis and spikes
+ # (n_trials, n_timepoints - ws + 1, n_neurons, n_coupling_basis)
+ conv_spikes = jnp.asarray(
+ nmo.utils.convolve_1d_trials(coupling_basis, [spikes]), dtype=jnp.float32
+ )
+
+ # create an individual neuron predictor by stacking the
+ # two convolved spike trains in a single feature vector
+ # and concatenate the trials.
+ conv_spikes = conv_spikes.reshape(
+ n_trials * (n_timepoints - window_size + 1), -1
+ )
+
+ # replicate for each neuron,
+ # (n_trials * (n_timepoints - ws + 1), n_neurons, n_neurons * n_coupling_basis)
+ conv_spikes = jnp.tile(conv_spikes, n_neurons).reshape(
+ conv_spikes.shape[0], n_neurons, conv_spikes.shape[1]
+ )
+
+ # add the feed-forward input to the predictors
+ X = jnp.concatenate((conv_spikes[1:], feedforward_input[:-window_size]), axis=2)
+
+ # fit the model
+ model.fit(X, spikes[:-window_size])
+
+ # simulate
+ model.simulate_recurrent(
+ random_key=random_key,
+ init_y=init_spikes,
+ coupling_basis_matrix=coupling_basis,
+ feedforward_input=feedforward_input,
+ )
diff --git a/tests/test_glm_runs.py b/tests/test_glm_runs.py
deleted file mode 100644
index df08f3af..00000000
--- a/tests/test_glm_runs.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import jax
-import numpy as np
-
-from nemos.basis import MSplineBasis
-from nemos.glm import GLM
-
-
-class DimensionMismatchError(Exception):
- """Exception raised for dimension mismatch errors."""
-
- def __init__(self, message):
- self.message = message
- super().__init__(self.message)
-
-
-def test_setup_msplinebasis():
- """
- Minimal test for MSplineBasis definition.
-
- Returns
- -------
- None
-
- Raises
- ------
- DimensionMismatchError
- If the output basis matrix has mismatched dimensions with the specified basis functions or window size.
-
- Notes
- -----
- This function performs a minimal test for defining the MBasis by generating basis functions using different orders.
- It checks if the output basis matrix has dimensions that match the specified number of basis functions and window size.
- """
- n_basis = 6
- window = 100
- for order in range(1, 6):
- spike_basis = MSplineBasis(n_basis_funcs=n_basis, order=order)
- spike_basis_matrix = spike_basis.evaluate(np.arange(window)).T
- if spike_basis_matrix.shape[0] != n_basis:
- raise DimensionMismatchError(
- f"The output basis matrix has {spike_basis_matrix.shape[1]} time points, while the number of basis specified is {n_basis}. They must agree."
- )
-
- if spike_basis_matrix.shape[1] != window:
- raise DimensionMismatchError(
- f"The output basis basis matrix has {spike_basis_matrix.shape[1]} window size, while the window size specified is {window}. They must agree."
- )
-
-
-def test_run_end_to_end_glm():
- nn, nt = 10, 1000
- key = jax.random.PRNGKey(123)
- key, subkey = jax.random.split(key)
- spike_data = jax.random.bernoulli(subkey, jax.numpy.ones((nn, nt)) * 0.5).astype(
- "int64"
- )
-
- spike_basis = MSplineBasis(n_basis_funcs=6, order=3)
- spike_basis_matrix = spike_basis.evaluate(np.arange(100)).T
- model = GLM(spike_basis_matrix)
-
- model.fit(spike_data)
- model.predict(spike_data)
- key, subkey = jax.random.split(key)
- X = model.simulate(subkey, 20, spike_data[:, :100])
diff --git a/tests/test_glm_synthetic.py b/tests/test_glm_synthetic.py
deleted file mode 100644
index dcf78e1b..00000000
--- a/tests/test_glm_synthetic.py
+++ /dev/null
@@ -1,96 +0,0 @@
-import matplotlib
-
-matplotlib.use("agg")
-
-import itertools
-
-import jax
-import jax.numpy as jnp
-import matplotlib.pyplot as plt
-import numpy as onp
-
-import nemos as nmo
-from nemos.basis import RaisedCosineBasisLinear
-from nemos.glm import GLM
-
-
-def test_set_up_glm():
- """Test the setup of the Generalized Linear Model (GLM).
-
- Returns
- -------
- GLM
- The simulated model of the Generalized Linear Model.
-
- Notes
- -----
- This function performs the setup for the Generalized Linear Model (GLM) by creating the necessary objects and variables.
- It generates a raised cosine basis, defines the simulated model using the basis functions, and returns the GLM object.
- """
- nn, nt, ws = 2, 1000, 100
- simulation_key = jax.random.PRNGKey(123)
-
- spike_basis = RaisedCosineBasisLinear(n_basis_funcs=5)
-
- B = spike_basis.evaluate(onp.linspace(0, 1, ws)).T
-
- w0 = onp.array([-0.1, -0.1, -0.2, -0.2, -1])
- w1 = onp.array([0, 0.1, 0.5, 0.1, 0])
-
- W = onp.empty((2, 5, 2))
- for i, j in itertools.product(range(nn), range(nn)):
- W[i, :, j] = w0 if (i == j) else w1
-
- simulated_model = GLM(B)
-
-
-def test_fit_glm2():
- jax.config.update("jax_platform_name", "cpu")
- jax.config.update("jax_enable_x64", True)
-
- nn, nt, ws = 2, 1000, 100
- simulation_key = jax.random.PRNGKey(123)
-
- spike_basis = RaisedCosineBasisLinear(n_basis_funcs=5)
-
- B = spike_basis.evaluate(onp.linspace(0, 1, ws)).T
-
- w0 = onp.array([-0.1, -0.1, -0.2, -0.2, -1])
- w1 = onp.array([0, 0.1, 0.5, 0.1, 0])
-
- W = onp.empty((2, 5, 2))
- for i, j in itertools.product(range(nn), range(nn)):
- W[i, :, j] = w0 if (i == j) else w1
-
- simulated_model = GLM(B)
- simulated_model.spike_basis_coeff_ = jnp.array(W)
- simulated_model.baseline_log_fr_ = jnp.ones(nn) * 0.1
-
- init_spikes = jnp.zeros((2, ws))
- spike_data = simulated_model.simulate(simulation_key, nt, init_spikes)
- sim_pred = simulated_model.predict(spike_data)
-
- fitted_model = GLM(
- B,
- solver_name="GradientDescent",
- solver_kwargs=dict(maxiter=1000, acceleration=False, verbose=True, stepsize=-1),
- )
-
- fitted_model.fit(spike_data)
- fit_pred = fitted_model.predict(spike_data)
-
- fig, axes = plt.subplots(2, 1)
- axes[0].plot(onp.arange(nt), spike_data[0])
- axes[0].plot(onp.arange(ws, nt + 1), sim_pred[0])
- axes[0].plot(onp.arange(ws, nt + 1), fit_pred[0])
- axes[1].plot(onp.arange(nt), spike_data[1])
- axes[1].plot(onp.arange(ws, nt + 1), sim_pred[1])
- axes[1].plot(onp.arange(ws, nt + 1), fit_pred[1])
-
- fig, axes = plt.subplots(nn, nn, sharey=True)
- for i, j in itertools.product(range(nn), range(nn)):
- axes[i, j].plot(B.T @ simulated_model.spike_basis_coeff_[i, :, j], label="true")
- axes[i, j].plot(B.T @ fitted_model.spike_basis_coeff_[i, :, j], label="est")
- axes[i, j].axhline(0, dashes=[2, 2], color="k")
- axes[-1, -1].legend()
- plt.close("all")
diff --git a/tests/test_glm_synthetic_single_neuron.py b/tests/test_glm_synthetic_single_neuron.py
deleted file mode 100644
index f1a3083e..00000000
--- a/tests/test_glm_synthetic_single_neuron.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import matplotlib
-
-matplotlib.use("agg")
-
-import jax
-import jax.numpy as jnp
-import matplotlib.pyplot as plt
-import numpy as onp
-
-import nemos as nmo
-from nemos.basis import RaisedCosineBasisLog
-from nemos.glm import GLM
-
-
-def test_glm_fit():
- jax.config.update("jax_platform_name", "cpu")
- jax.config.update("jax_enable_x64", True)
-
- nn, nt, ws = 1, 1000, 100
- simulation_key = jax.random.PRNGKey(123)
-
- spike_basis = RaisedCosineBasisLog(n_basis_funcs=5)
- B = spike_basis.evaluate(onp.linspace(0, 1, ws)).T
-
- simulated_model = GLM(B)
- simulated_model.spike_basis_coeff_ = jnp.array([0, 0, -1, -1, -1])[None, :, None]
- simulated_model.baseline_log_fr_ = jnp.ones(nn) * 0.1
-
- init_spikes = jnp.zeros((nn, ws))
- spike_data = simulated_model.simulate(simulation_key, nt, init_spikes)
- sim_pred = simulated_model.predict(spike_data)
-
- fitted_model = GLM(
- B,
- solver_name="GradientDescent",
- solver_kwargs=dict(
- maxiter=1000, acceleration=False, verbose=True, stepsize=0.0
- ),
- )
-
- fitted_model.fit(spike_data)
- fit_pred = fitted_model.predict(spike_data)
-
- fig, ax = plt.subplots(1, 1)
- ax.plot(onp.arange(nt), spike_data[0])
- ax.plot(onp.arange(ws, nt + 1), sim_pred[0])
- ax.plot(onp.arange(ws, nt + 1), fit_pred[0])
-
- fig, ax = plt.subplots(1, 1, sharey=True)
- ax.plot(B.T @ simulated_model.spike_basis_coeff_[0, :, 0], label="true")
- ax.plot(B.T @ fitted_model.spike_basis_coeff_[0, :, 0], label="est")
- ax.axhline(0, dashes=[2, 2], color="k")
- ax.legend()
-
- plt.close("all")
diff --git a/tests/test_observation_models.py b/tests/test_observation_models.py
new file mode 100644
index 00000000..803b49fb
--- /dev/null
+++ b/tests/test_observation_models.py
@@ -0,0 +1,168 @@
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+import scipy.stats as sts
+import statsmodels.api as sm
+
+import nemos as nmo
+
+
+@pytest.fixture()
+def poisson_observations():
+ return nmo.observation_models.PoissonObservations
+
+
+class TestPoissonObservations:
+ @pytest.mark.parametrize("link_function", [jnp.exp, jax.nn.softplus, 1])
+ def test_initialization_link_is_callable(self, link_function, poisson_observations):
+ """Check that the observation model initializes when a callable is passed."""
+ raise_exception = not callable(link_function)
+ if raise_exception:
+ with pytest.raises(
+ TypeError,
+ match="The `inverse_link_function` function must be a Callable",
+ ):
+ poisson_observations(link_function)
+ else:
+ poisson_observations(link_function)
+
+ @pytest.mark.parametrize(
+ "link_function", [jnp.exp, np.exp, lambda x: x, sm.families.links.log()]
+ )
+ def test_initialization_link_is_jax(self, link_function, poisson_observations):
+ """Check that the observation model initializes when a callable is passed."""
+ raise_exception = isinstance(link_function, np.ufunc) | isinstance(
+ link_function, sm.families.links.Link
+ )
+ if raise_exception:
+ with pytest.raises(
+ TypeError,
+ match="The `inverse_link_function` must return a jax.numpy.ndarray",
+ ):
+ poisson_observations(link_function)
+ else:
+ poisson_observations(link_function)
+
+ @pytest.mark.parametrize("link_function", [jnp.exp, jax.nn.softplus, 1])
+ def test_initialization_link_is_callable_set_params(
+ self, link_function, poisson_observations
+ ):
+ """Check that the observation model initializes when a callable is passed."""
+ observation_model = poisson_observations()
+ raise_exception = not callable(link_function)
+ if raise_exception:
+ with pytest.raises(
+ TypeError,
+ match="The `inverse_link_function` function must be a Callable",
+ ):
+ observation_model.set_params(inverse_link_function=link_function)
+ else:
+ observation_model.set_params(inverse_link_function=link_function)
+
+ @pytest.mark.parametrize(
+ "link_function", [jnp.exp, np.exp, lambda x: x, sm.families.links.log()]
+ )
+ def test_initialization_link_is_jax_set_params(
+ self, link_function, poisson_observations
+ ):
+ """Check that the observation model initializes when a callable is passed."""
+ raise_exception = isinstance(link_function, np.ufunc) | isinstance(
+ link_function, sm.families.links.Link
+ )
+ observation_model = poisson_observations()
+ if raise_exception:
+ with pytest.raises(
+ TypeError,
+ match="The `inverse_link_function` must return a jax.numpy.ndarray!",
+ ):
+ observation_model.set_params(inverse_link_function=link_function)
+ else:
+ observation_model.set_params(inverse_link_function=link_function)
+
+ @pytest.mark.parametrize(
+ "link_function",
+ [
+ jnp.exp,
+ lambda x: jnp.exp(x) if isinstance(x, jnp.ndarray) else "not a number",
+ ],
+ )
+ def test_initialization_link_returns_scalar(
+ self, link_function, poisson_observations
+ ):
+ """Check that the observation model initializes when a callable is passed."""
+ raise_exception = not isinstance(link_function(1.0), (jnp.ndarray, float))
+ observation_model = poisson_observations()
+ if raise_exception:
+ with pytest.raises(
+ TypeError,
+ match="The `inverse_link_function` must handle scalar inputs correctly",
+ ):
+ observation_model.set_params(inverse_link_function=link_function)
+ else:
+ observation_model.set_params(inverse_link_function=link_function)
+
+ def test_deviance_against_statsmodels(self, poissonGLM_model_instantiation):
+ """
+ Compare the deviance computation to statsmodels.
+ Assesses if the model's deviance matches statsmodels' results.
+ """
+ _, y, model, _, firing_rate = poissonGLM_model_instantiation
+ dev = sm.families.Poisson().deviance(y, firing_rate)
+ dev_model = model.observation_model.deviance(firing_rate, y).sum()
+ if not np.allclose(dev, dev_model):
+ raise ValueError("Deviance doesn't match statsmodels!")
+
+ def test_loglikelihood_against_scipy(self, poissonGLM_model_instantiation):
+ """
+ Compare the log-likelihood to scipy.
+ Assesses if the model's log-likelihood matches scipy's results.
+ """
+ _, y, model, _, firing_rate = poissonGLM_model_instantiation
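+ # Note: the model's negative log-likelihood omits the constant log(y!)
+ # normalization, so gammaln(y + 1) is subtracted below to recover the
+ # full Poisson logpmf, log P(y | mu) = y*log(mu) - mu - log(y!),
+ # before comparing with scipy.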
+ ll_model = (
+ -model.observation_model.negative_log_likelihood(firing_rate, y).sum()
+ - jax.scipy.special.gammaln(y + 1).mean()
+ )
+ ll_scipy = sts.poisson(firing_rate).logpmf(y).mean()
+ if not np.allclose(ll_model, ll_scipy):
+ raise ValueError("Log-likelihood doesn't match scipy!")
+
+ @pytest.mark.parametrize("score_type", ["pseudo-r2-Cohen", "pseudo-r2-McFadden"])
+ def test_pseudo_r2_range(self, score_type, poissonGLM_model_instantiation):
+ """
+ Compute the pseudo-r2 and check that is < 1.
+ """
+ _, y, model, _, firing_rate = poissonGLM_model_instantiation
+ pseudo_r2 = model.observation_model.pseudo_r2(
+ firing_rate, y, score_type=score_type
+ )
+ if (pseudo_r2 > 1) or (pseudo_r2 < 0):
+ raise ValueError(f"pseudo-r2 of {pseudo_r2} outside the [0,1] range!")
+
+ @pytest.mark.parametrize("score_type", ["pseudo-r2-Cohen", "pseudo-r2-McFadden"])
+ def test_pseudo_r2_mean(self, score_type, poissonGLM_model_instantiation):
+ """
+ Check that the pseudo-r2 of the null model is 0.
+ """
+ _, y, model, _, _ = poissonGLM_model_instantiation
+ pseudo_r2 = model.observation_model.pseudo_r2(
+ y.mean(), y, score_type=score_type
+ )
+ if not np.allclose(pseudo_r2, 0):
+ raise ValueError(
+ f"pseudo-r2 of {pseudo_r2} for the null model. Should be equal to 0!"
+ )
+
+ def test_emission_probability(self, poissonGLM_model_instantiation):
+ """
+ Test the poisson emission probability.
+
+ Check that the emission probability is set to jax.random.poisson.
+ """
+ _, _, model, _, _ = poissonGLM_model_instantiation
+ key_array = jax.random.PRNGKey(123)
+ counts = model.observation_model.sample_generator(key_array, np.arange(1, 11))
+ if not jnp.all(counts == jax.random.poisson(key_array, np.arange(1, 11))):
+ raise ValueError(
+ "The emission probability should output the results of a call to jax.random.poisson."
+ )
diff --git a/tests/test_proximal_operator.py b/tests/test_proximal_operator.py
new file mode 100644
index 00000000..4addd1fa
--- /dev/null
+++ b/tests/test_proximal_operator.py
@@ -0,0 +1,52 @@
+import jax.numpy as jnp
+
+from nemos.proximal_operator import _vmap_norm2_masked_2, prox_group_lasso
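+
+# Informal background for these tests: `prox_group_lasso` applies the
+# group-lasso proximal update, shrinking each masked group of weights and
+# zeroing groups whose l2 norm falls below the regularization threshold;
+# `_vmap_norm2_masked_2` computes the per-group l2 norms used in that update.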
+
+
+def test_prox_group_lasso_returns_tuple(example_data_prox_operator):
+ """Test whether prox_group_lasso returns a tuple."""
+ params, alpha, mask, scaling = example_data_prox_operator
+ updated_params = prox_group_lasso(params, alpha, mask, scaling)
+ assert isinstance(updated_params, tuple)
+
+
+def test_prox_group_lasso_tuple_length(example_data_prox_operator):
+ """Test whether the tuple returned by prox_group_lasso has a length of 2."""
+ params, alpha, mask, scaling = example_data_prox_operator
+ updated_params = prox_group_lasso(params, alpha, mask, scaling)
+ assert len(updated_params) == 2
+
+
+def test_prox_group_lasso_weights_shape(example_data_prox_operator):
+ """Test whether the shape of the weights in prox_group_lasso is correct."""
+ params, alpha, mask, scaling = example_data_prox_operator
+ updated_params = prox_group_lasso(params, alpha, mask, scaling)
+ assert updated_params[0].shape == params[0].shape
+
+
+def test_prox_group_lasso_intercepts_shape(example_data_prox_operator):
+ """Test whether the shape of the intercepts in prox_group_lasso is correct."""
+ params, alpha, mask, scaling = example_data_prox_operator
+ updated_params = prox_group_lasso(params, alpha, mask, scaling)
+ assert updated_params[1].shape == params[1].shape
+
+
+def test_vmap_norm2_masked_2_returns_array(example_data_prox_operator):
+ """Test whether _vmap_norm2_masked_2 returns a NumPy array."""
+ params, _, mask, _ = example_data_prox_operator
+ l2_norm = _vmap_norm2_masked_2(params[0], mask)
+ assert isinstance(l2_norm, jnp.ndarray)
+
+
+def test_vmap_norm2_masked_2_shape(example_data_prox_operator):
+ """Test whether the shape of the result from _vmap_norm2_masked_2 is correct."""
+ params, _, mask, _ = example_data_prox_operator
+ l2_norm = _vmap_norm2_masked_2(params[0], mask)
+ assert l2_norm.shape == (params[0].shape[0], mask.shape[0])
+
+
+def test_vmap_norm2_masked_2_non_negative(example_data_prox_operator):
+ """Test whether all elements of the result from _vmap_norm2_masked_2 are non-negative."""
+ params, _, mask, _ = example_data_prox_operator
+ l2_norm = _vmap_norm2_masked_2(params[0], mask)
+ assert jnp.all(l2_norm >= 0)
diff --git a/tests/test_regularizer.py b/tests/test_regularizer.py
new file mode 100644
index 00000000..b8a6fe1c
--- /dev/null
+++ b/tests/test_regularizer.py
@@ -0,0 +1,739 @@
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+import statsmodels.api as sm
+from sklearn.linear_model import PoissonRegressor
+
+import nemos as nmo
+
+
+class TestUnRegularized:
+ cls = nmo.regularizer.UnRegularized
+
+ @pytest.mark.parametrize(
+ "solver_name",
+ ["GradientDescent", "BFGS", "ProximalGradient", "AGradientDescent", 1],
+ )
+ def test_init_solver_name(self, solver_name):
+ """Test UnRegularized acceptable solvers."""
+ acceptable_solvers = [
+ "GradientDescent",
+ "BFGS",
+ "LBFGS",
+ "ScipyMinimize",
+ "NonlinearCG",
+ "ScipyBoundedMinimize",
+ "LBFGSB",
+ ]
+ raise_exception = solver_name not in acceptable_solvers
+ if raise_exception:
+ with pytest.raises(
+ ValueError, match=f"Solver `{solver_name}` not allowed for "
+ ):
+ self.cls(solver_name)
+ else:
+ self.cls(solver_name)
+
+ @pytest.mark.parametrize(
+ "solver_name",
+ ["GradientDescent", "BFGS", "ProximalGradient", "AGradientDescent", 1],
+ )
+ def test_set_solver_name_allowed(self, solver_name):
+ """Test UnRegularized acceptable solvers."""
+ acceptable_solvers = [
+ "GradientDescent",
+ "BFGS",
+ "LBFGS",
+ "ScipyMinimize",
+ "NonlinearCG",
+ "ScipyBoundedMinimize",
+ "LBFGSB",
+ ]
+ regularizer = self.cls("GradientDescent")
+ raise_exception = solver_name not in acceptable_solvers
+ if raise_exception:
+ with pytest.raises(
+ ValueError, match=f"Solver `{solver_name}` not allowed for "
+ ):
+ regularizer.set_params(solver_name=solver_name)
+ else:
+ regularizer.set_params(solver_name=solver_name)
+
+ @pytest.mark.parametrize("solver_name", ["GradientDescent", "BFGS"])
+ @pytest.mark.parametrize("solver_kwargs", [{"tol": 10**-10}, {"tols": 10**-10}])
+ def test_init_solver_kwargs(self, solver_name, solver_kwargs):
+ """Test RidgeSolver acceptable kwargs."""
+
+ raise_exception = "tols" in list(solver_kwargs.keys())
+ if raise_exception:
+ with pytest.raises(
+ NameError, match="kwargs {'tols'} in solver_kwargs not a kwarg"
+ ):
+ self.cls(solver_name, solver_kwargs=solver_kwargs)
+ else:
+ self.cls(solver_name, solver_kwargs=solver_kwargs)
+
+ @pytest.mark.parametrize("loss", [lambda a, b, c: 0, 1, None, {}])
+ def test_loss_is_callable(self, loss):
+ """Test that the loss function is a callable"""
+ raise_exception = not callable(loss)
+ if raise_exception:
+ with pytest.raises(TypeError, match="The `loss` must be a Callable"):
+ self.cls("GradientDescent").instantiate_solver(loss)
+ else:
+ self.cls("GradientDescent").instantiate_solver(loss)
+
+ @pytest.mark.parametrize("solver_name", ["GradientDescent", "BFGS"])
+ def test_run_solver(self, solver_name, poissonGLM_model_instantiation):
+ """Test that the solver runs."""
+
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ runner = self.cls("GradientDescent").instantiate_solver(model._predict_and_compute_loss)
+ runner((true_params[0] * 0.0, true_params[1]), X, y)
+
+ def test_solver_output_match(self, poissonGLM_model_instantiation):
+ """Test that different solvers converge to the same solution."""
+ jax.config.update("jax_enable_x64", True)
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ # set precision to float64 for accurate matching of the results
+ model.data_type = jnp.float64
+ runner_gd = self.cls("GradientDescent", {"tol": 10**-12}).instantiate_solver(
+ model._predict_and_compute_loss
+ )
+ runner_bfgs = self.cls("BFGS", {"tol": 10**-12}).instantiate_solver(
+ model._predict_and_compute_loss
+ )
+ runner_scipy = self.cls(
+ "ScipyMinimize", {"method": "BFGS", "tol": 10**-12}
+ ).instantiate_solver(model._predict_and_compute_loss)
+ weights_gd, intercepts_gd = runner_gd(
+ (true_params[0] * 0.0, true_params[1]), X, y
+ )[0]
+ weights_bfgs, intercepts_bfgs = runner_bfgs(
+ (true_params[0] * 0.0, true_params[1]), X, y
+ )[0]
+ weights_scipy, intercepts_scipy = runner_scipy(
+ (true_params[0] * 0.0, true_params[1]), X, y
+ )[0]
+
+ match_weights = np.allclose(weights_gd, weights_bfgs) and np.allclose(
+ weights_gd, weights_scipy
+ )
+ match_intercepts = np.allclose(intercepts_gd, intercepts_bfgs) and np.allclose(
+ intercepts_gd, intercepts_scipy
+ )
+ if (not match_weights) or (not match_intercepts):
+ raise ValueError(
+ "Convex estimators should converge to the same numerical value."
+ )
+
+ def test_solver_match_sklearn(self, poissonGLM_model_instantiation):
+ """Test that different solvers converge to the same solution."""
+ jax.config.update("jax_enable_x64", True)
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ # set precision to float64 for accurate matching of the results
+ model.data_type = jnp.float64
+ regularizer = self.cls("GradientDescent", {"tol": 10**-12})
+ runner = regularizer.instantiate_solver(model._predict_and_compute_loss)
+ weights, intercepts = runner(
+ (true_params[0] * 0.0, true_params[1]), X, y
+ )[0]
+ model_skl = PoissonRegressor(fit_intercept=True, tol=10**-12, alpha=0.0)
+ model_skl.fit(X[:, 0], y[:, 0])
+
+ match_weights = np.allclose(model_skl.coef_, weights.flatten())
+ match_intercepts = np.allclose(model_skl.intercept_, intercepts.flatten())
+ if (not match_weights) or (not match_intercepts):
+ raise ValueError("Ridge GLM regularizer estimate does not match sklearn!")
+
+
+class TestRidge:
+ cls = nmo.regularizer.Ridge
+
+ @pytest.mark.parametrize(
+ "solver_name",
+ ["GradientDescent", "BFGS", "ProximalGradient", "AGradientDescent", 1],
+ )
+ def test_init_solver_name(self, solver_name):
+ """Test RidgeSolver acceptable solvers."""
+ acceptable_solvers = [
+ "GradientDescent",
+ "BFGS",
+ "LBFGS",
+ "ScipyMinimize",
+ "NonlinearCG",
+ "ScipyBoundedMinimize",
+ "LBFGSB",
+ ]
+ raise_exception = solver_name not in acceptable_solvers
+ if raise_exception:
+ with pytest.raises(
+ ValueError, match=f"Solver `{solver_name}` not allowed for "
+ ):
+ self.cls(solver_name)
+ else:
+ self.cls(solver_name)
+
+ @pytest.mark.parametrize(
+ "solver_name",
+ ["GradientDescent", "BFGS", "ProximalGradient", "AGradientDescent", 1],
+ )
+ def test_set_solver_name_allowed(self, solver_name):
+ """Test RidgeSolver acceptable solvers."""
+ acceptable_solvers = [
+ "GradientDescent",
+ "BFGS",
+ "LBFGS",
+ "ScipyMinimize",
+ "NonlinearCG",
+ "ScipyBoundedMinimize",
+ "LBFGSB",
+ ]
+ regularizer = self.cls("GradientDescent")
+ raise_exception = solver_name not in acceptable_solvers
+ if raise_exception:
+ with pytest.raises(
+ ValueError, match=f"Solver `{solver_name}` not allowed for "
+ ):
+ regularizer.set_params(solver_name=solver_name)
+ else:
+ regularizer.set_params(solver_name=solver_name)
+
+ @pytest.mark.parametrize("solver_name", ["GradientDescent", "BFGS"])
+ @pytest.mark.parametrize("solver_kwargs", [{"tol": 10**-10}, {"tols": 10**-10}])
+ def test_init_solver_kwargs(self, solver_name, solver_kwargs):
+ """Test Ridge acceptable kwargs."""
+
+ raise_exception = "tols" in list(solver_kwargs.keys())
+ if raise_exception:
+ with pytest.raises(
+ NameError, match="kwargs {'tols'} in solver_kwargs not a kwarg"
+ ):
+ self.cls(solver_name, solver_kwargs=solver_kwargs)
+ else:
+ self.cls(solver_name, solver_kwargs=solver_kwargs)
+
+ @pytest.mark.parametrize("loss", [lambda a, b, c: 0, 1, None, {}])
+ def test_loss_is_callable(self, loss):
+ """Test that the loss function is a callable"""
+ raise_exception = not callable(loss)
+ if raise_exception:
+ with pytest.raises(TypeError, match="The `loss` must be a Callable"):
+ self.cls("GradientDescent").instantiate_solver(loss)
+ else:
+ self.cls("GradientDescent").instantiate_solver(loss)
+
+ @pytest.mark.parametrize("solver_name", ["GradientDescent", "BFGS"])
+ def test_run_solver(self, solver_name, poissonGLM_model_instantiation):
+ """Test that the solver runs."""
+
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ runner = self.cls("GradientDescent").instantiate_solver(model._predict_and_compute_loss)
+ runner((true_params[0] * 0.0, true_params[1]), X, y)
+
+ def test_solver_output_match(self, poissonGLM_model_instantiation):
+ """Test that different solvers converge to the same solution."""
+ jax.config.update("jax_enable_x64", True)
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ # set precision to float64 for accurate matching of the results
+ model.data_type = jnp.float64
+ runner_gd = self.cls("GradientDescent", {"tol": 10**-12}).instantiate_solver(
+ model._predict_and_compute_loss
+ )
+ runner_bfgs = self.cls("BFGS", {"tol": 10**-12}).instantiate_solver(
+ model._predict_and_compute_loss
+ )
+ runner_scipy = self.cls(
+ "ScipyMinimize", {"method": "BFGS", "tol": 10**-12}
+ ).instantiate_solver(model._predict_and_compute_loss)
+ weights_gd, intercepts_gd = runner_gd(
+ (true_params[0] * 0.0, true_params[1]), X, y
+ )[0]
+ weights_bfgs, intercepts_bfgs = runner_bfgs(
+ (true_params[0] * 0.0, true_params[1]), X, y
+ )[0]
+ weights_scipy, intercepts_scipy = runner_scipy(
+ (true_params[0] * 0.0, true_params[1]), X, y
+ )[0]
+
+ match_weights = np.allclose(weights_gd, weights_bfgs) and np.allclose(
+ weights_gd, weights_scipy
+ )
+ match_intercepts = np.allclose(intercepts_gd, intercepts_bfgs) and np.allclose(
+ intercepts_gd, intercepts_scipy
+ )
+ if (not match_weights) or (not match_intercepts):
+ raise ValueError(
+ "Convex estimators should converge to the same numerical value."
+ )
+
+ def test_solver_match_sklearn(self, poissonGLM_model_instantiation):
+ """Test that different solvers converge to the same solution."""
+ jax.config.update("jax_enable_x64", True)
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ # set precision to float64 for accurate matching of the results
+ model.data_type = jnp.float64
+ regularizer = self.cls("GradientDescent", {"tol": 10**-12})
+ runner = regularizer.instantiate_solver(model._predict_and_compute_loss)
+ weights, intercepts = runner(
+ (true_params[0] * 0.0, true_params[1]), X, y
+ )[0]
+ model_skl = PoissonRegressor(
+ fit_intercept=True, tol=10**-12, alpha=regularizer.regularizer_strength
+ )
+ model_skl.fit(X[:, 0], y[:, 0])
+
+ match_weights = np.allclose(model_skl.coef_, weights.flatten())
+ match_intercepts = np.allclose(model_skl.intercept_, intercepts.flatten())
+ if (not match_weights) or (not match_intercepts):
+ raise ValueError("Ridge GLM solver estimate does not match sklearn!")
+
+
+class TestLasso:
+ cls = nmo.regularizer.Lasso
+
+ @pytest.mark.parametrize(
+ "solver_name",
+ ["GradientDescent", "BFGS", "ProximalGradient", "AGradientDescent", 1],
+ )
+ def test_init_solver_name(self, solver_name):
+ """Test Lasso acceptable solvers."""
+ acceptable_solvers = ["ProximalGradient"]
+ raise_exception = solver_name not in acceptable_solvers
+ if raise_exception:
+ with pytest.raises(
+ ValueError, match=f"Solver `{solver_name}` not allowed for "
+ ):
+ self.cls(solver_name)
+ else:
+ self.cls(solver_name)
+
+ @pytest.mark.parametrize(
+ "solver_name",
+ ["GradientDescent", "BFGS", "ProximalGradient", "AGradientDescent", 1],
+ )
+ def test_set_solver_name_allowed(self, solver_name):
+ """Test Lasso acceptable solvers."""
+ acceptable_solvers = ["ProximalGradient"]
+ regularizer = self.cls("ProximalGradient")
+ raise_exception = solver_name not in acceptable_solvers
+ if raise_exception:
+ with pytest.raises(
+ ValueError, match=f"Solver `{solver_name}` not allowed for "
+ ):
+ regularizer.set_params(solver_name=solver_name)
+ else:
+ regularizer.set_params(solver_name=solver_name)
+
+ @pytest.mark.parametrize("solver_kwargs", [{"tol": 10**-10}, {"tols": 10**-10}])
+ def test_init_solver_kwargs(self, solver_kwargs):
+ """Test LassoSolver acceptable kwargs."""
+ raise_exception = "tols" in list(solver_kwargs.keys())
+ if raise_exception:
+ with pytest.raises(
+ NameError, match="kwargs {'tols'} in solver_kwargs not a kwarg"
+ ):
+ self.cls("ProximalGradient", solver_kwargs=solver_kwargs)
+ else:
+ self.cls("ProximalGradient", solver_kwargs=solver_kwargs)
+
+ @pytest.mark.parametrize("loss", [lambda a, b, c: 0, 1, None, {}])
+ def test_loss_callable(self, loss):
+ """Test that the loss function is a callable"""
+ raise_exception = not callable(loss)
+ if raise_exception:
+ with pytest.raises(TypeError, match="The `loss` must be a Callable"):
+ self.cls("ProximalGradient").instantiate_solver(loss)
+ else:
+ self.cls("ProximalGradient").instantiate_solver(loss)
+
+ def test_run_solver(self, poissonGLM_model_instantiation):
+ """Test that the solver runs."""
+
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ runner = self.cls("ProximalGradient").instantiate_solver(model._predict_and_compute_loss)
+ runner((true_params[0] * 0.0, true_params[1]), X, y)
+
+ def test_solver_match_statsmodels(self, poissonGLM_model_instantiation):
+ """Test that different solvers converge to the same solution."""
+ jax.config.update("jax_enable_x64", True)
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ # set precision to float64 for accurate matching of the results
+ model.data_type = jnp.float64
+ regularizer = self.cls("ProximalGradient", {"tol": 10**-12})
+ runner = regularizer.instantiate_solver(model._predict_and_compute_loss)
+ weights, intercepts = runner((true_params[0] * 0.0, true_params[1]), X, y)[0]
+
+ # instantiate the glm with statsmodels
+ glm_sm = sm.GLM(
+ endog=y[:, 0], exog=sm.add_constant(X[:, 0]), family=sm.families.Poisson()
+ )
+
+ # regularize everything except intercept
+ alpha_sm = np.ones(X.shape[2] + 1) * regularizer.regularizer_strength
+ alpha_sm[0] = 0
+
+ # pure lasso = elastic net with L1 weight = 1
+ res_sm = glm_sm.fit_regularized(
+ method="elastic_net", alpha=alpha_sm, L1_wt=1.0, cnvrg_tol=10**-12
+ )
+ # compare params
+ sm_params = res_sm.params
+ glm_params = jnp.hstack((intercepts, weights.flatten()))
+ match_weights = np.allclose(sm_params, glm_params)
+ if not match_weights:
+ raise ValueError("Lasso GLM solver estimate does not match statsmodels!")
+
+
+class TestGroupLasso:
+ cls = nmo.regularizer.GroupLasso
+
+ @pytest.mark.parametrize(
+ "solver_name",
+ ["GradientDescent", "BFGS", "ProximalGradient", "AGradientDescent", 1],
+ )
+ def test_init_solver_name(self, solver_name):
+ """Test GroupLasso acceptable solvers."""
+ acceptable_solvers = ["ProximalGradient"]
+ raise_exception = solver_name not in acceptable_solvers
+
+ # create a valid mask
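+ # (mask rows index groups, columns index features; a 1 assigns the
+ # feature to that group, and each feature belongs to exactly one group)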
+ mask = np.zeros((2, 10))
+ mask[0, :5] = 1
+ mask[1, 5:] = 1
+ mask = jnp.asarray(mask)
+
+ if raise_exception:
+ with pytest.raises(
+ ValueError, match=f"Solver `{solver_name}` not allowed for "
+ ):
+ self.cls(solver_name, mask)
+ else:
+ self.cls(solver_name, mask)
+
+ @pytest.mark.parametrize(
+ "solver_name",
+ ["GradientDescent", "BFGS", "ProximalGradient", "AGradientDescent", 1],
+ )
+ def test_set_solver_name_allowed(self, solver_name):
+ """Test GroupLassoSolver acceptable solvers."""
+ acceptable_solvers = ["ProximalGradient"]
+ # create a valid mask
+ mask = np.zeros((2, 10))
+ mask[0, :5] = 1
+ mask[1, 5:] = 1
+ mask = jnp.asarray(mask)
+ regularizer = self.cls("ProximalGradient", mask=mask)
+ raise_exception = solver_name not in acceptable_solvers
+ if raise_exception:
+ with pytest.raises(
+ ValueError, match=f"Solver `{solver_name}` not allowed for "
+ ):
+ regularizer.set_params(solver_name=solver_name)
+ else:
+ regularizer.set_params(solver_name=solver_name)
+
+ @pytest.mark.parametrize("solver_kwargs", [{"tol": 10**-10}, {"tols": 10**-10}])
+ def test_init_solver_kwargs(self, solver_kwargs):
+ """Test GroupLasso acceptable kwargs."""
+ raise_exception = "tols" in list(solver_kwargs.keys())
+
+ # create a valid mask
+ mask = np.zeros((2, 10))
+ mask[0, :5] = 1
+ mask[1, 5:] = 1
+ mask = jnp.asarray(mask)
+
+ if raise_exception:
+ with pytest.raises(
+ NameError, match="kwargs {'tols'} in solver_kwargs not a kwarg"
+ ):
+ self.cls("ProximalGradient", mask, solver_kwargs=solver_kwargs)
+ else:
+ self.cls("ProximalGradient", mask, solver_kwargs=solver_kwargs)
+
+ @pytest.mark.parametrize("loss", [lambda a, b, c: 0, 1, None, {}])
+ def test_loss_callable(self, loss):
+ """Test that the loss function is a callable"""
+ raise_exception = not callable(loss)
+
+ # create a valid mask
+ mask = np.zeros((2, 10))
+ mask[0, :5] = 1
+ mask[1, 5:] = 1
+ mask = jnp.asarray(mask)
+
+ if raise_exception:
+ with pytest.raises(TypeError, match="The `loss` must be a Callable"):
+ self.cls("ProximalGradient", mask).instantiate_solver(loss)
+ else:
+ self.cls("ProximalGradient", mask).instantiate_solver(loss)
+
+ def test_run_solver(self, poissonGLM_model_instantiation):
+ """Test that the solver runs."""
+
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+
+ # create a valid mask
+ mask = np.zeros((2, X.shape[2]))
+ mask[0, :2] = 1
+ mask[1, 2:] = 1
+ mask = jnp.asarray(mask)
+
+ runner = self.cls("ProximalGradient", mask).instantiate_solver(model._predict_and_compute_loss)
+ runner((true_params[0] * 0.0, true_params[1]), X, y)
+
+ @pytest.mark.parametrize("n_groups_assign", [0, 1, 2])
+ def test_mask_validity_groups(
+ self, n_groups_assign, group_sparse_poisson_glm_model_instantiation
+ ):
+ """Test that mask assigns at most 1 group to each weight."""
+ raise_exception = n_groups_assign > 1
+ (
+ X,
+ y,
+ model,
+ true_params,
+ firing_rate,
+ _,
+ ) = group_sparse_poisson_glm_model_instantiation
+
+ # create a valid mask
+ mask = np.zeros((2, X.shape[2]))
+ mask[0, :2] = 1
+ mask[1, 2:] = 1
+
+ # change assignment
+ if n_groups_assign == 0:
+ mask[:, 3] = 0
+ elif n_groups_assign == 2:
+ mask[:, 3] = 1
+
+ mask = jnp.asarray(mask)
+
+ if raise_exception:
+ with pytest.raises(
+ ValueError, match="Incorrect group assignment. " "Some of the features"
+ ):
+ self.cls("ProximalGradient", mask).instantiate_solver(model._predict_and_compute_loss)
+ else:
+ self.cls("ProximalGradient", mask).instantiate_solver(model._predict_and_compute_loss)
+
+ @pytest.mark.parametrize("set_entry", [0, 1, -1, 2, 2.5])
+ def test_mask_validity_entries(self, set_entry, poissonGLM_model_instantiation):
+ """Test that mask is composed of 0s and 1s."""
+ raise_exception = set_entry not in {0, 1}
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+
+ # create a valid mask
+ mask = np.zeros((2, X.shape[2]))
+ mask[0, :2] = 1
+ mask[1, 2:] = 1
+ # assign an entry
+ mask[1, 2] = set_entry
+ mask = jnp.asarray(mask, dtype=jnp.float32)
+
+ if raise_exception:
+ with pytest.raises(ValueError, match="Mask elements be 0s and 1s"):
+ self.cls("ProximalGradient", mask).instantiate_solver(model._predict_and_compute_loss)
+ else:
+ self.cls("ProximalGradient", mask).instantiate_solver(model._predict_and_compute_loss)
+
+ @pytest.mark.parametrize("n_dim", [0, 1, 2, 3])
+ def test_mask_dimension(self, n_dim, poissonGLM_model_instantiation):
+ """Test that mask is composed of 0s and 1s."""
+
+ raise_exception = n_dim != 2
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+
+ # create a valid mask
+ if n_dim == 0:
+ mask = np.array([])
+ elif n_dim == 1:
+ mask = np.ones((1,))
+ elif n_dim == 2:
+ mask = np.zeros((2, X.shape[2]))
+ mask[0, :2] = 1
+ mask[1, 2:] = 1
+ else:
+ mask = np.zeros((2, X.shape[2]) + (1,) * (n_dim - 2))
+ mask[0, :2] = 1
+ mask[1, 2:] = 1
+
+ mask = jnp.asarray(mask, dtype=jnp.float32)
+
+ if raise_exception:
+ with pytest.raises(ValueError, match="`mask` must be 2-dimensional"):
+ self.cls("ProximalGradient", mask).instantiate_solver(model._predict_and_compute_loss)
+ else:
+ self.cls("ProximalGradient", mask).instantiate_solver(model._predict_and_compute_loss)
+
+ @pytest.mark.parametrize("n_groups", [0, 1, 2])
+ def test_mask_n_groups(self, n_groups, poissonGLM_model_instantiation):
+ """Test that mask has at least 1 group."""
+ raise_exception = n_groups < 1
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+
+ # create a mask
+ mask = np.zeros((n_groups, X.shape[2]))
+ if n_groups > 0:
+ for i in range(n_groups - 1):
+ mask[i, i : i + 1] = 1
+ mask[-1, n_groups - 1 :] = 1
+
+ mask = jnp.asarray(mask, dtype=jnp.float32)
+
+ if raise_exception:
+ with pytest.raises(ValueError, match=r"Empty mask provided! Mask has "):
+ self.cls("ProximalGradient", mask).instantiate_solver(model._predict_and_compute_loss)
+ else:
+ self.cls("ProximalGradient", mask).instantiate_solver(model._predict_and_compute_loss)
+
+ def test_group_sparsity_enforcement(
+ self, group_sparse_poisson_glm_model_instantiation
+ ):
+ """Test that group lasso works on a simple dataset."""
+ (
+ X,
+ y,
+ model,
+ true_params,
+ firing_rate,
+ _,
+ ) = group_sparse_poisson_glm_model_instantiation
+ zeros_true = true_params[0].flatten() == 0
+ mask = np.zeros((2, X.shape[2]))
+ mask[0, zeros_true] = 1
+ mask[1, ~zeros_true] = 1
+ mask = jnp.asarray(mask, dtype=jnp.float32)
+
+ runner = self.cls("ProximalGradient", mask).instantiate_solver(model._predict_and_compute_loss)
+ params, _ = runner((true_params[0] * 0.0, true_params[1]), X, y)
+
+ zeros_est = params[0] == 0
+ if not np.all(zeros_est == zeros_true):
+ raise ValueError("GroupLasso failed to zero-out the parameter group!")
+
+ ###########
+ # Test mask from set_params
+ ###########
+ @pytest.mark.parametrize("n_groups_assign", [0, 1, 2])
+ def test_mask_validity_groups_set_params(
+ self, n_groups_assign, group_sparse_poisson_glm_model_instantiation
+ ):
+ """Test that mask assigns at most 1 group to each weight."""
+ raise_exception = n_groups_assign > 1
+ (
+ X,
+ y,
+ model,
+ true_params,
+ firing_rate,
+ _,
+ ) = group_sparse_poisson_glm_model_instantiation
+
+ # create a valid mask
+ mask = np.zeros((2, X.shape[2]))
+ mask[0, :2] = 1
+ mask[1, 2:] = 1
+ regularizer = self.cls("ProximalGradient", mask)
+
+ # change assignment
+ if n_groups_assign == 0:
+ mask[:, 3] = 0
+ elif n_groups_assign == 2:
+ mask[:, 3] = 1
+
+ mask = jnp.asarray(mask)
+
+ if raise_exception:
+ with pytest.raises(
+ ValueError, match="Incorrect group assignment. " "Some of the features"
+ ):
+ regularizer.set_params(mask=mask)
+ else:
+ regularizer.set_params(mask=mask)
+
+ @pytest.mark.parametrize("set_entry", [0, 1, -1, 2, 2.5])
+ def test_mask_validity_entries_set_params(
+ self, set_entry, poissonGLM_model_instantiation
+ ):
+ """Test that mask is composed of 0s and 1s."""
+ raise_exception = set_entry not in {0, 1}
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+
+ # create a valid mask
+ mask = np.zeros((2, X.shape[2]))
+ mask[0, :2] = 1
+ mask[1, 2:] = 1
+ regularizer = self.cls("ProximalGradient", mask)
+
+ # assign an entry
+ mask[1, 2] = set_entry
+ mask = jnp.asarray(mask, dtype=jnp.float32)
+
+ if raise_exception:
+ with pytest.raises(ValueError, match="Mask elements be 0s and 1s"):
+ regularizer.set_params(mask=mask)
+ else:
+ regularizer.set_params(mask=mask)
+
+ @pytest.mark.parametrize("n_dim", [0, 1, 2, 3])
+ def test_mask_dimension_set_params(self, n_dim, poissonGLM_model_instantiation):
+ """Test that the mask is 2-dimensional."""
+
+ raise_exception = n_dim != 2
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+
+ valid_mask = np.zeros((2, X.shape[2]))
+ valid_mask[0, :1] = 1
+ valid_mask[1, 1:] = 1
+ regularizer = self.cls("ProximalGradient", valid_mask)
+
+ # create a mask
+ if n_dim == 0:
+ mask = np.array([])
+ elif n_dim == 1:
+ mask = np.ones((1,))
+ elif n_dim == 2:
+ mask = np.zeros((2, X.shape[2]))
+ mask[0, :2] = 1
+ mask[1, 2:] = 1
+ else:
+ mask = np.zeros((2, X.shape[2]) + (1,) * (n_dim - 2))
+ mask[0, :2] = 1
+ mask[1, 2:] = 1
+
+ mask = jnp.asarray(mask, dtype=jnp.float32)
+
+ if raise_exception:
+ with pytest.raises(ValueError, match="`mask` must be 2-dimensional"):
+ regularizer.set_params(mask=mask)
+ else:
+ regularizer.set_params(mask=mask)
+
+ @pytest.mark.parametrize("n_groups", [0, 1, 2])
+ def test_mask_n_groups_set_params(self, n_groups, poissonGLM_model_instantiation):
+ """Test that mask has at least 1 group."""
+ raise_exception = n_groups < 1
+ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
+ valid_mask = np.zeros((2, X.shape[2]))
+ valid_mask[0, :1] = 1
+ valid_mask[1, 1:] = 1
+ regularizer = self.cls("ProximalGradient", valid_mask)
+
+ # create a mask
+ mask = np.zeros((n_groups, X.shape[2]))
+ if n_groups > 0:
+ for i in range(n_groups - 1):
+ mask[i, i : i + 1] = 1
+ mask[-1, n_groups - 1 :] = 1
+
+ mask = jnp.asarray(mask, dtype=jnp.float32)
+
+ if raise_exception:
+ with pytest.raises(ValueError, match=r"Empty mask provided! Mask has "):
+ regularizer.set_params(mask=mask)
+ else:
+ regularizer.set_params(mask=mask)
diff --git a/tests/test_simulation.py b/tests/test_simulation.py
new file mode 100644
index 00000000..53d8ff42
--- /dev/null
+++ b/tests/test_simulation.py
@@ -0,0 +1,169 @@
+import itertools
+from contextlib import nullcontext as does_not_raise
+
+import numpy as np
+import pytest
+
+import nemos.basis as basis
+import nemos.simulation as simulation
+
+
+@pytest.mark.parametrize(
+ "inhib_a, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="Gamma parameter [a-z]+_[ab] must be >0.")),
+ (0, pytest.raises(ValueError, match="Gamma parameter [a-z]+_[ab] must be >0.")),
+ (1, does_not_raise()),
+ ],
+ )
+def test_difference_of_gammas_inhib_a(inhib_a, expectation):
+ with expectation:
+ simulation.difference_of_gammas(10, inhib_a=inhib_a)
+
+
+@pytest.mark.parametrize(
+ "excit_a, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="Gamma parameter [a-z]+_[ab] must be >0.")),
+ (0, pytest.raises(ValueError, match="Gamma parameter [a-z]+_[ab] must be >0.")),
+ (1, does_not_raise()),
+ ],
+ )
+def test_difference_of_gammas_excit_a(excit_a, expectation):
+ with expectation:
+ simulation.difference_of_gammas(10, excit_a=excit_a)
+
+
+@pytest.mark.parametrize(
+ "inhib_b, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="Gamma parameter [a-z]+_[ab] must be >0.")),
+ (0, pytest.raises(ValueError, match="Gamma parameter [a-z]+_[ab] must be >0.")),
+ (1, does_not_raise()),
+ ],
+ )
+def test_difference_of_gammas_inhib_b(inhib_b, expectation):
+ with expectation:
+ simulation.difference_of_gammas(10, inhib_b=inhib_b)
+
+
+@pytest.mark.parametrize(
+ "excit_b, expectation",
+ [
+ (-1, pytest.raises(ValueError, match="Gamma parameter [a-z]+_[ab] must be >0.")),
+ (0, pytest.raises(ValueError, match="Gamma parameter [a-z]+_[ab] must be >0.")),
+ (1, does_not_raise()),
+ ],
+ )
+def test_difference_of_gammas_excit_b(excit_b, expectation):
+ with expectation:
+ simulation.difference_of_gammas(10, excit_b=excit_b)
+
+
+@pytest.mark.parametrize(
+ "upper_percentile, expectation",
+ [
+ (-0.1, pytest.raises(ValueError, match=r"upper_percentile should lie in the \[0, 1\) interval.")),
+ (0, does_not_raise()),
+ (0.1, does_not_raise()),
+ (1, pytest.raises(ValueError, match=r"upper_percentile should lie in the \[0, 1\) interval.")),
+ (10, pytest.raises(ValueError, match=r"upper_percentile should lie in the \[0, 1\) interval.")),
+ ],
+ )
+def test_difference_of_gammas_percentile_params(upper_percentile, expectation):
+ with expectation:
+ simulation.difference_of_gammas(10, upper_percentile)
+
+
+@pytest.mark.parametrize("window_size", [0, 1, 2])
+def test_difference_of_gammas_output_shape(window_size):
+ result_size = simulation.difference_of_gammas(window_size).size
+ assert result_size == window_size, f"Expected output size {window_size}, but got {result_size}"
+
+
+@pytest.mark.parametrize("window_size", [1, 2, 10])
+def test_difference_of_gammas_output_norm(window_size):
+ result = simulation.difference_of_gammas(window_size)
+ assert np.allclose(np.linalg.norm(result, ord=2), 1), "The output of difference_of_gammas is not unit norm."
+
+
+@pytest.mark.parametrize(
+ "coupling_filters, expectation",
+ [
+ (np.zeros((10, )), pytest.raises(ValueError, match=r"coupling_filters must be a 3 dimensional array")),
+ (np.zeros((10, 2)), pytest.raises(ValueError, match=r"coupling_filters must be a 3 dimensional array")),
+ (np.zeros((10, 2, 2)), does_not_raise()),
+ (np.zeros((10, 2, 2, 2)), pytest.raises(ValueError, match=r"coupling_filters must be a 3 dimensional array"))
+ ],
+ )
+def test_regress_filter_coupling_filters_dim(coupling_filters, expectation):
+ ws = coupling_filters.shape[0]
+ with expectation:
+ simulation.regress_filter(coupling_filters, np.zeros((ws, 3)))
+
+
+@pytest.mark.parametrize(
+ "eval_basis, expectation",
+ [
+ (np.zeros((10, )), pytest.raises(ValueError, match=r"eval_basis must be a 2 dimensional array")),
+ (np.zeros((10, 2)), does_not_raise()),
+ (np.zeros((10, 2, 2)), pytest.raises(ValueError, match=r"eval_basis must be a 2 dimensional array")),
+ (np.zeros((10, 2, 2, 2)), pytest.raises(ValueError, match=r"eval_basis must be a 2 dimensional array"))
+ ],
+ )
+def test_regress_filter_eval_basis_dim(eval_basis, expectation):
+ ws = eval_basis.shape[0]
+ with expectation:
+ simulation.regress_filter(np.zeros((ws, 1, 1)), eval_basis)
+
+
+@pytest.mark.parametrize(
+ "delta_ws, expectation",
+ [
+ (-1, pytest.raises(ValueError, match=r"window_size mismatch\. The window size of ")),
+ (0, does_not_raise()),
+ (1, pytest.raises(ValueError, match=r"window_size mismatch\. The window size of ")),
+ ],
+ )
+def test_regress_filter_window_size_matching(delta_ws, expectation):
+ ws = 2
+ with expectation:
+ simulation.regress_filter(np.zeros((ws, 1, 1)), np.zeros((ws + delta_ws, 1)))
+
+
+@pytest.mark.parametrize(
+ "window_size, n_neurons_sender, n_neurons_receiver, n_basis_funcs",
+ list(itertools.product([1, 2], repeat=4)),
+ )
+def test_regress_filter_weights_size(window_size, n_neurons_sender, n_neurons_receiver, n_basis_funcs):
+ weights = simulation.regress_filter(
+ np.zeros((window_size, n_neurons_sender, n_neurons_receiver)),
+ np.zeros((window_size, n_basis_funcs))
+ )
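+ # Shape bookkeeping for one hypothetical case: coupling_filters of shape
+ # (window_size=2, n_neurons_sender=1, n_neurons_receiver=2) and eval_basis of
+ # shape (window_size=2, n_basis_funcs=2) should yield weights of shape (1, 2, 2).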
+ assert weights.shape[0] == n_neurons_sender, ("First dimension of weights (n_neurons_sender) does not "
+ "match the second dimension of coupling_filters.")
+ assert weights.shape[1] == n_neurons_receiver, ("Second dimension of weights (n_neurons_receiver) does not "
+ "match the third dimension of coupling_filters.")
+ assert weights.shape[2] == n_basis_funcs, ("Third dimension of weights (n_basis_funcs) does not "
+ "match the second dimension of eval_basis.")
+
+
+def test_least_square_correctness():
+ """
+ Test the correctness of the least-squares estimate by enforcing an invertible map,
+ i.e. a map for which the least-squares estimator recovers the original weights.
+ """
+ # set up problem dimensionality
+ ws, n_neurons_receiver, n_neurons_sender, n_basis_funcs = 100, 1, 2, 10
+ # evaluate a basis
+ _, eval_basis = basis.RaisedCosineBasisLog(n_basis_funcs).evaluate_on_grid(ws)
+ # generate random weights to define filters
+ weights = np.random.normal(size=(n_neurons_receiver, n_neurons_sender, n_basis_funcs))
+ # define filters as linear combination of basis elements
+ coupling_filt = np.einsum("ijk, tk -> tij", weights, eval_basis)
+ # recover weights by means of linear regression
+ weights_lsq = simulation.regress_filter(coupling_filt, eval_basis)
+ # check the exact matching of the filters up to numerical error
+ assert np.allclose(weights_lsq, weights)
+
+
diff --git a/tox.ini b/tox.ini
index df1d07c2..96e65259 100644
--- a/tox.ini
+++ b/tox.ini
@@ -14,7 +14,8 @@ package_cache = .tox/cache
# while black, isort and flake8 are also i
commands =
black --check src
- isort --check src
+ isort src
+ isort docs/examples
flake8 --config={toxinidir}/tox.ini src
pytest