diff --git a/docs/tutorials/plot_06_calcium_imaging.py b/docs/tutorials/plot_06_calcium_imaging.py index bbd167ee..54d83666 100644 --- a/docs/tutorials/plot_06_calcium_imaging.py +++ b/docs/tutorials/plot_06_calcium_imaging.py @@ -28,24 +28,6 @@ nap.nap_config.suppress_conversion_warnings = True warnings.filterwarnings("ignore", category=UserWarning, message="The feature matrix is not of dtype") -# %% -# Later in the tutorial, we will compute the rank of the GLM feature matrix. -# To ensure the computation accounts for the model's intercept term, -# we can add a constant column of ones before calculating the rank. -# Below is a utility function for adding the intercept column. - -def add_intercept(X): - """Add an intercept term to design matrix. - - Convert matrix to float64, drops nans and add intercept term. - """ - # convert to float64 for rank computation precision - X = np.asarray(X, dtype=np.float64) - # drop nans - X = X[nmo.tree_utils.get_valid_multitree(X)] - return np.hstack([np.ones((X.shape[0], 1)), X]) - - # %% @@ -201,48 +183,41 @@ def add_intercept(X): X = basis.compute_features(head_direction, Y[:, selected_neurons]) # %% -# A design matrix $X$ including a CyclicBSpline and an intercept term is rank deficient, i.e. one can find non-zero -# coefficients $\mathbf{w} \neq \mathbf{0}$ such that $X \cdot \mathbf{w} = \mathbf{0}$. - -print(f"Number of features: {X.shape[1] + 1}") # num coefficients + intercept -print(f"Matrix rank: {np.linalg.matrix_rank(add_intercept(X))}") - -# %% -# By setting, - -w = np.ones((X.shape[1] + 1)) -w[0] = -1 -w[1 + heading_basis.n_basis_funcs:] = 0 - - -# %% -# We have that, - -np.max(np.abs(np.dot(add_intercept(X), w))) - -# %% -# This implies that there will be infinite different parameters that results in the same firing rate, -# or equivalently there will be infinite many equivalent solutions to an un-regularized GLM. - -# define some random coefficients -coef = np.random.randn(X.shape[1] + 1) - -# the firing rate is softplus([1, X] * coef) -# adding w to the coefficients does not change the output rate. -firing_rate = jax.nn.softplus(np.dot(add_intercept(X), coef)) -firing_rate_2 = jax.nn.softplus(np.dot(add_intercept(X), coef + w)) - -# check that the rate match -np.allclose(firing_rate, firing_rate_2) - -# %% -# We can avoid this issue by dropping linearly dependent columns in the design matrix. - -X, idx = apply_identifiability_constraints_by_basis_component(basis, X) - -print(f"Number of features: {X.shape[1] + 1}") # drops one column -print(f"Matrix rank: {np.linalg.matrix_rank(add_intercept(X))}") - +# +# Before we use this design matrix to fit the population, we need to take a brief detour +# into linear algebra. Depending on your design matrix is constructed, it is likely to +# be rank-deficient, in which case it has a null space. Practically, that means that +# there are infinitely many different sets of parameters that predict the same firing +# rate. If you want to interpret your parameters, this is bad! +# +# While this multiplicity of solutions is always a potential issue when fitting models, +# it is particularly relevant when using basis objects in nemos, as many of our basis +# sets completely tile the input space (i.e., summing across all $n$ basis functions +# returns 1 everywhere), which, when combined with the intercept term always present in +# the GLM (i.e., the base firing rate), will give you a rank-deficient matrix. +# +# We thus recommend that you always check the rank of your design matrix and provide +# some tools to drop the linearly-dependent columns, if necessary, which will guarantee +# that your design matrix is full rank and thus that there is one unique solution. +# +# !!! tip "Linear Algebra" +# +# To read more about matrix rank, see +# [Wikipedia](https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Main_definitions). +# Gil Strang's Linear Algebra course, [available for free +# online](https://web.mit.edu/18.06/www/), and [NeuroMatch +# Academy](https://compneuro.neuromatch.io/tutorials/W0D3_LinearAlgebra/student/W0D3_Tutorial2.html) +# are also great resources. +# +# In this case, we are using the CyclicBSpline basis functions, which uniformly tile and +# thus will result in a rank-deficient matrix. Therefore, we will use a utility function +# to drop a column from the matrix and make it full-rank: + +# The number of features is the number of columns plus one (for the intercept) +print(f"Number of features in the rank-deficient design matrix: {X.shape[1] + 1}") +X, _ = apply_identifiability_constraints_by_basis_component(basis, X) +# We have dropped one column +print(f"Number of features in the full-rank design matrix: {X.shape[1] + 1}") # %% # ## Train & test set