From 4315ce3ba991cdd9113571bef29b79559e337b78 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 5 Nov 2024 10:24:34 -0500 Subject: [PATCH] rank def added text --- docs/tutorials/plot_06_calcium_imaging.py | 36 ++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/docs/tutorials/plot_06_calcium_imaging.py b/docs/tutorials/plot_06_calcium_imaging.py index 4772f69a..cbe1f192 100644 --- a/docs/tutorials/plot_06_calcium_imaging.py +++ b/docs/tutorials/plot_06_calcium_imaging.py @@ -28,6 +28,12 @@ nap.nap_config.suppress_conversion_warnings = True warnings.filterwarnings("ignore", category=UserWarning, message="The feature matrix is not of dtype") +# %% +# Later in the tutorial, we will compute the rank of the GLM feature matrix. +# To ensure the computation accounts for the model's intercept term, +# we can add a constant column of ones before calculating the rank. +# Below is a utility function for adding the intercept column. + def add_intercept(X): """Add an intercept term to design matrix. @@ -195,13 +201,41 @@ def add_intercept(X): X = basis.compute_features(head_direction, Y[:, selected_neurons]) # %% -# A design matrix including a CyclicBSpline and an intercept term is rank deficient. +# A design matrix $X$ including a CyclicBSpline and an intercept term is rank deficient, i.e. one can find non-zero +# coefficients $\mathbf{w} \neq \mathbf{0}$ such that $X \cdot \mathbf{w} = \mathbf{0}$. print(f"Number of features: {X.shape[1] + 1}") # num coefficients + intercept print(f"Matrix rank: {np.linalg.matrix_rank(add_intercept(X))}") +# %% +# By setting, + +w = np.ones((X.shape[1] + 1)) +w[0] = -1 +w[1 + heading_basis.n_basis_funcs:] = 0 + + +# %% +# We have that, + +np.max(np.abs(np.dot(add_intercept(X), w))) + +# %% # This implies that there will be infinite different parameters that results in the same firing rate, # or equivalently there will be infinite many equivalent solutions to an un-regularized GLM. + +# define some random coefficients +coef = np.random.randn(X.shape[1] + 1) + +# the firing rate is softplus([1, X] * coef) +# adding w to the coefficients does not change the output rate. +firing_rate = jax.nn.softplus(np.dot(add_intercept(X), coef)) +firing_rate_2 = jax.nn.softplus(np.dot(add_intercept(X), coef + w)) + +# check that the rate match +np.allclose(firing_rate, firing_rate_2) + +# %% # We can avoid this issue by dropping linearly dependent columns in the design matrix. X, idx = apply_identifiability_constraints_by_basis_component(basis, X)