From 2a815b4d0f0f5ebaed812c947867fcd6fab24e57 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 14 Dec 2023 11:25:02 -0500 Subject: [PATCH] added test of correctness for the lsq --- tests/test_simulation.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 5f8fba51..53d8ff42 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -4,6 +4,7 @@ import numpy as np import pytest +import nemos.basis as basis import nemos.simulation as simulation @@ -145,3 +146,24 @@ def test_regress_filter_weights_size(window_size, n_neurons_sender, n_neurons_re f"match the third dimension of coupling_filters.") assert weights.shape[2] == n_basis_funcs, (f"Third dimension of weights (n_basis_funcs) does not " f"match the second dimension of eval_basis.") + + +def test_least_square_correctness(): + """ + Test the correctness of the least square estimate by enforcing an invertible map, + i.e. a map for which the least-square estimator matches the original weights. + """ + # set up problem dimensionality + ws, n_neurons_receiver, n_neurons_sender, n_basis_funcs = 100, 1, 2, 10 + # evaluate a basis + _, eval_basis = basis.RaisedCosineBasisLog(n_basis_funcs).evaluate_on_grid(ws) + # generate random weights to define filters + weights = np.random.normal(size=(n_neurons_receiver, n_neurons_sender, n_basis_funcs)) + # define filters as linear combination of basis elements + coupling_filt = np.einsum("ijk, tk -> tij", weights, eval_basis) + # recover weights by means of linear regression + weights_lsq = simulation.regress_filter(coupling_filt, eval_basis) + # check the exact matching of the filters up to numerical error + assert np.allclose(weights_lsq, weights) + +