Skip to content

Commit

Permalink
Merge branch 'coeff_parsing_support' of github.com:flatironinstitute/…
Browse files Browse the repository at this point in the history
…nemos into coeff_parsing_support
  • Loading branch information
BalzaniEdoardo committed Nov 5, 2024
2 parents 44fca2e + 19ee425 commit d84e003
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 61 deletions.
95 changes: 35 additions & 60 deletions docs/tutorials/plot_06_calcium_imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,6 @@
nap.nap_config.suppress_conversion_warnings = True
warnings.filterwarnings("ignore", category=UserWarning, message="The feature matrix is not of dtype")

# %%
# Later in the tutorial, we will compute the rank of the GLM feature matrix.
# To ensure the computation accounts for the model's intercept term,
# we can add a constant column of ones before calculating the rank.
# Below is a utility function for adding the intercept column.

def add_intercept(X):
"""Add an intercept term to design matrix.
Convert matrix to float64, drops nans and add intercept term.
"""
# convert to float64 for rank computation precision
X = np.asarray(X, dtype=np.float64)
# drop nans
X = X[nmo.tree_utils.get_valid_multitree(X)]
return np.hstack([np.ones((X.shape[0], 1)), X])




# %%
Expand Down Expand Up @@ -201,48 +183,41 @@ def add_intercept(X):
X = basis.compute_features(head_direction, Y[:, selected_neurons])

# %%
# A design matrix $X$ including a CyclicBSpline and an intercept term is rank deficient, i.e. one can find non-zero
# coefficients $\mathbf{w} \neq \mathbf{0}$ such that $X \cdot \mathbf{w} = \mathbf{0}$.

print(f"Number of features: {X.shape[1] + 1}") # num coefficients + intercept
print(f"Matrix rank: {np.linalg.matrix_rank(add_intercept(X))}")

# %%
# By setting,

w = np.ones((X.shape[1] + 1))
w[0] = -1
w[1 + heading_basis.n_basis_funcs:] = 0


# %%
# We have that,

np.max(np.abs(np.dot(add_intercept(X), w)))

# %%
# This implies that there will be infinite different parameters that results in the same firing rate,
# or equivalently there will be infinite many equivalent solutions to an un-regularized GLM.

# define some random coefficients
coef = np.random.randn(X.shape[1] + 1)

# the firing rate is softplus([1, X] * coef)
# adding w to the coefficients does not change the output rate.
firing_rate = jax.nn.softplus(np.dot(add_intercept(X), coef))
firing_rate_2 = jax.nn.softplus(np.dot(add_intercept(X), coef + w))

# check that the rate match
np.allclose(firing_rate, firing_rate_2)

# %%
# We can avoid this issue by dropping linearly dependent columns in the design matrix.

X, idx = apply_identifiability_constraints_by_basis_component(basis, X)

print(f"Number of features: {X.shape[1] + 1}") # drops one column
print(f"Matrix rank: {np.linalg.matrix_rank(add_intercept(X))}")

#
# Before we use this design matrix to fit the population, we need to take a brief detour
# into linear algebra. Depending on your design matrix is constructed, it is likely to
# be rank-deficient, in which case it has a null space. Practically, that means that
# there are infinitely many different sets of parameters that predict the same firing
# rate. If you want to interpret your parameters, this is bad!
#
# While this multiplicity of solutions is always a potential issue when fitting models,
# it is particularly relevant when using basis objects in nemos, as many of our basis
# sets completely tile the input space (i.e., summing across all $n$ basis functions
# returns 1 everywhere), which, when combined with the intercept term always present in
# the GLM (i.e., the base firing rate), will give you a rank-deficient matrix.
#
# We thus recommend that you always check the rank of your design matrix and provide
# some tools to drop the linearly-dependent columns, if necessary, which will guarantee
# that your design matrix is full rank and thus that there is one unique solution.
#
# !!! tip "Linear Algebra"
#
# To read more about matrix rank, see
# [Wikipedia](https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Main_definitions).
# Gil Strang's Linear Algebra course, [available for free
# online](https://web.mit.edu/18.06/www/), and [NeuroMatch
# Academy](https://compneuro.neuromatch.io/tutorials/W0D3_LinearAlgebra/student/W0D3_Tutorial2.html)
# are also great resources.
#
# In this case, we are using the CyclicBSpline basis functions, which uniformly tile and
# thus will result in a rank-deficient matrix. Therefore, we will use a utility function
# to drop a column from the matrix and make it full-rank:

# The number of features is the number of columns plus one (for the intercept)
print(f"Number of features in the rank-deficient design matrix: {X.shape[1] + 1}")
X, _ = apply_identifiability_constraints_by_basis_component(basis, X)
# We have dropped one column
print(f"Number of features in the full-rank design matrix: {X.shape[1] + 1}")

# %%
# ## Train & test set
Expand Down
2 changes: 1 addition & 1 deletion src/nemos/identifiability_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def _add_invalid_entries(feature_matrix, shape_first_axis, is_valid):
def _apply_identifiability_constraints(
feature_matrix: JaxArray,
preprocessing_func: Callable = add_constant,
warn_if_float32=True,
warn_if_float32: bool = True,
) -> Tuple[JaxArray, JaxArray]:
"""
Apply identifiability constraints to a design matrix `feature_matrix`.
Expand Down

0 comments on commit d84e003

Please sign in to comment.