Skip to content

Commit

Permalink
rank def added text
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Nov 5, 2024
1 parent 2f5481d commit 4315ce3
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion docs/tutorials/plot_06_calcium_imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@
nap.nap_config.suppress_conversion_warnings = True
warnings.filterwarnings("ignore", category=UserWarning, message="The feature matrix is not of dtype")

# %%
# Later in the tutorial, we will compute the rank of the GLM feature matrix.
# To ensure the computation accounts for the model's intercept term,
# we can add a constant column of ones before calculating the rank.
# Below is a utility function for adding the intercept column.

def add_intercept(X):
"""Add an intercept term to design matrix.
Expand Down Expand Up @@ -195,13 +201,41 @@ def add_intercept(X):
X = basis.compute_features(head_direction, Y[:, selected_neurons])

# %%
# A design matrix including a CyclicBSpline and an intercept term is rank deficient.
# A design matrix $X$ including a CyclicBSpline and an intercept term is rank deficient, i.e. one can find non-zero
# coefficients $\mathbf{w} \neq \mathbf{0}$ such that $X \cdot \mathbf{w} = \mathbf{0}$.

print(f"Number of features: {X.shape[1] + 1}") # num coefficients + intercept
print(f"Matrix rank: {np.linalg.matrix_rank(add_intercept(X))}")

# %%
# By setting,

w = np.ones((X.shape[1] + 1))
w[0] = -1
w[1 + heading_basis.n_basis_funcs:] = 0


# %%
# We have that,

np.max(np.abs(np.dot(add_intercept(X), w)))

# %%
# This implies that there will be infinite different parameters that results in the same firing rate,
# or equivalently there will be infinite many equivalent solutions to an un-regularized GLM.

# define some random coefficients
coef = np.random.randn(X.shape[1] + 1)

# the firing rate is softplus([1, X] * coef)
# adding w to the coefficients does not change the output rate.
firing_rate = jax.nn.softplus(np.dot(add_intercept(X), coef))
firing_rate_2 = jax.nn.softplus(np.dot(add_intercept(X), coef + w))

# check that the rate match
np.allclose(firing_rate, firing_rate_2)

# %%
# We can avoid this issue by dropping linearly dependent columns in the design matrix.

X, idx = apply_identifiability_constraints_by_basis_component(basis, X)
Expand Down

0 comments on commit 4315ce3

Please sign in to comment.