Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added glm.py examples in docstrings #249

Merged
merged 23 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/nemos/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def fit(self, X: FeatureMatrix, y=None):
-------
self :
The transformer object.

pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
"""
self._basis._set_kernel(*self._unpack_inputs(X))
return self
Expand Down
99 changes: 98 additions & 1 deletion src/nemos/glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,25 @@ class GLM(BaseRegressor):
TypeError
If provided `regularizer` or `observation_model` are not valid.

Examples
--------
>>> import nemos as nmo

>>> # define simple GLM model
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> model = nmo.glm.GLM()
>>> print("Regularizer type: ", type(model.regularizer))
Regularizer type: <class 'nemos.regularizer.UnRegularized'>
>>> print("Observation model: ", type(model.observation_model))
Observation model: <class 'nemos.observation_models.PoissonObservations'>


>>> # define GLM model of PoissonObservations model with soft-plus NL
>>> observation_models = nmo.observation_models.PoissonObservations(jax.nn.softplus)
>>> model = nmo.glm.GLM(observation_model=observation_models, solver_name="LBFGS")
>>> print("Regularizer type: ", type(model.regularizer))
Regularizer type: <class 'nemos.regularizer.UnRegularized'>
>>> print("Observation model: ", type(model.observation_model))
Observation model: <class 'nemos.observation_models.PoissonObservations'>
"""

def __init__(
Expand Down Expand Up @@ -307,6 +326,21 @@ def predict(self, X: DESIGN_INPUT_TYPE) -> jnp.ndarray:
- If `X` is not three-dimensional.
- If there's an inconsistent number of features between spike basis coefficients and `X`.

Examples
--------
>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> Xnew = np.random.normal(size=(20, ) + X.shape[1:])

>>> # define and fit a GLM
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)

>>> # predict new spike data
>>> predicted_spikes = model.predict(Xnew)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved

See Also
--------
- [score](./#nemos.glm.GLM.score)
Expand Down Expand Up @@ -407,6 +441,20 @@ def score(
If X structure doesn't match the params, and if X and y have different
number of samples.

Examples
--------
>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved

>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)

>>> # get model score
>>> log_likelihood_score = model.score(X, y)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> pseudo_r2_score = model.score(X, y, score_type='pseudo-r2-McFadden')
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved

Notes
-----
The log-likelihood is not on a standard scale, its value is influenced by many factors,
Expand Down Expand Up @@ -605,6 +653,20 @@ def fit(
- If `init_params` are not array-like
- If `init_params[i]` cannot be converted to jnp.ndarray for all i

Examples
-------
>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved

>>> # fit a ridge regression Poisson GLM
>>> import nemos as nmo
>>> model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1)
>>> model = model.fit(X, y)

>>> # get model weights
>>> model_weights = model.coef_
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved

"""
# validate the inputs & initialize solver
init_params = self.initialize_params(X, y, init_params=init_params)
Expand Down Expand Up @@ -686,7 +748,7 @@ def simulate(
Returns
-------
simulated_activity :
Simulated activity (spike counts for PoissonGLMs) for the neuron over time.
Simulated activity (spike counts for Poisson GLMs) for the neuron over time.
Shape: (n_time_bins, ).
firing_rates :
Simulated rates for the neuron over time. Shape, (n_time_bins, ).
Expand All @@ -698,6 +760,21 @@ def simulate(
ValueError
- If the instance has not been previously fitted.

Examples
--------
>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> Xnew = np.random.normal(size=(20, ) + X.shape[1:])

>>> # define and fit model
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)

>>> # generate spikes and rates
>>> random_key = jax.random.key(123)
>>> spikes, rates = model.simulate(random_key, Xnew)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved

See Also
--------
Expand Down Expand Up @@ -1376,6 +1453,26 @@ def fit(
- If the mask is a `FeaturePytree`, then `"feature_name"` is a predictor of neuron `j` if
`feature_mask["feature_name"][j] == 1`.

Examples
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
--------
>>> # Generate sample data
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from nemos.glm import PopulationGLM

>>> # Define predictors (X), weights, and neural activity (y)
>>> num_samples, num_features, num_neurons = 100, 3, 2
>>> X = np.random.normal(size=(num_samples, num_features))
>>> weights = np.array([[ 0.5, 0. ], [-0.5, -0.5], [ 0. , 1. ]])
>>> y = np.random.poisson(np.exp(X.dot(weights)))

>>> # Define a feature mask, shape (num_features, num_neurons)
>>> feature_mask = jnp.array([[1, 0], [1, 1], [0, 1]])

>>> # Create and fit the model
>>> model = PopulationGLM(feature_mask=feature_mask).fit(X, y)
>>> print(model.coef_.shape)
(3, 2)
"""
return super().fit(X, y, init_params)

Expand Down
Loading