Skip to content

Commit

Permalink
added test of correctness for the lsq
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 14, 2023
1 parent 46092ed commit 2a815b4
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions tests/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pytest

import nemos.basis as basis
import nemos.simulation as simulation


Expand Down Expand Up @@ -145,3 +146,24 @@ def test_regress_filter_weights_size(window_size, n_neurons_sender, n_neurons_re
f"match the third dimension of coupling_filters.")
assert weights.shape[2] == n_basis_funcs, (f"Third dimension of weights (n_basis_funcs) does not "
f"match the second dimension of eval_basis.")


def test_least_square_correctness():
"""
Test the correctness of the least square estimate by enforcing an invertible map,
i.e. a map for which the least-square estimator matches the original weights.
"""
# set up problem dimensionality
ws, n_neurons_receiver, n_neurons_sender, n_basis_funcs = 100, 1, 2, 10
# evaluate a basis
_, eval_basis = basis.RaisedCosineBasisLog(n_basis_funcs).evaluate_on_grid(ws)
# generate random weights to define filters
weights = np.random.normal(size=(n_neurons_receiver, n_neurons_sender, n_basis_funcs))
# define filters as linear combination of basis elements
coupling_filt = np.einsum("ijk, tk -> tij", weights, eval_basis)
# recover weights by means of linear regression
weights_lsq = simulation.regress_filter(coupling_filt, eval_basis)
# check the exact matching of the filters up to numerical error
assert np.allclose(weights_lsq, weights)


0 comments on commit 2a815b4

Please sign in to comment.