Merge pull request #182 from flatironinstitute/group_lasso_docs
improved docstrings
BalzaniEdoardo authored Jul 9, 2024
2 parents e88ec9c + 0f7196f commit 23433ce
Showing 2 changed files with 76 additions and 8 deletions.
58 changes: 51 additions & 7 deletions src/nemos/glm.py
@@ -561,8 +561,8 @@ def _initialize_parameters(
(either as a FeaturePytree or ndarray, matching the structure of X) with shapes (n_features,).
- The second element is the initialized intercept (bias terms) as an ndarray of shape (1,).
Example
-------
Examples
--------
>>> import nemos as nmo
>>> import numpy as np
>>> X = np.zeros((100, 5)) # Example input
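Editor's sketch: the doctest above is cut off by the collapsed hunk, so as a rough, hedged illustration (not part of this commit), parameter initialization might be exercised as follows. It assumes `_initialize_parameters` accepts the design matrix and observations and returns the (coefficients, intercept) pair described above; the exact signature is hidden by the collapsed diff.

import numpy as np
import nemos as nmo

X = np.zeros((100, 5))                    # example input, 100 samples x 5 features
y = np.random.poisson(1.0, size=(100,))   # example spike counts
model = nmo.glm.GLM()
# hypothetical call; signature assumed from the docstring text above
coef, intercept = model._initialize_parameters(X, y)
print(coef.shape, intercept.shape)        # expected: (5,) and (1,)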
@@ -868,8 +868,8 @@ def initialize_solver(
- If `params` are not array-like when provided.
- If `init_params[i]` cannot be converted to jnp.ndarray for all i
Example
-------
Examples
--------
>>> X, y = load_data() # Hypothetical function to load data
>>> params, opt_state = model.initialize_solver(X, y)
>>> # Now ready to run optimization or update steps
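As a minimal, hedged sketch (not from the diff itself), the hypothetical `load_data()` above can be replaced with simulated data; this assumes `initialize_solver(X, y)` returns the `(params, opt_state)` pair shown in the docstring.

import numpy as np
import nemos as nmo

# simulate a small Poisson GLM dataset in place of the hypothetical load_data()
X = np.random.normal(size=(100, 5))
true_weights = np.array([0.5, -0.5, 0.0, 1.0, 0.0])
y = np.random.poisson(np.exp(X @ true_weights))

model = nmo.glm.GLM()
params, opt_state = model.initialize_solver(X, y)
# params and opt_state can now be passed to update(), as in the next hunk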
@@ -951,8 +951,8 @@ def update(
indicating an invalid update step, typically due to numerical instabilities
or inappropriate solver configurations.
Example
-------
Examples
--------
>>> # Assume glm_instance is an instance of GLM that has been previously fitted.
>>> params = glm_instance.coef_, glm_instance.intercept_
>>> opt_state = glm_instance.solver_state
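Continuing that sketch, a manual optimization loop with `update` might look like the following. This is a hedged illustration that assumes `update(params, opt_state, X, y)` returns the new `(params, opt_state)` pair and, as the docstring above notes, yields NaN parameters when an update step is invalid.

import jax
import jax.numpy as jnp

num_steps = 500
for _ in range(num_steps):
    params, opt_state = model.update(params, opt_state, X, y)
    # NaN parameters indicate an invalid step (numerical instability or a
    # mis-configured solver), per the docstring above
    if any(jnp.isnan(leaf).any() for leaf in jax.tree_util.tree_leaves(params)):
        raise RuntimeError("update produced NaN parameters; check the solver configuration")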
@@ -1005,7 +1005,8 @@ class PopulationGLM(GLM):
and related parameters.
Default is UnRegularized regression with gradient descent.
feature_mask :
Either a matrix of shape (num_features, num_neurons) or a [FeaturePytree](../pytrees) of 0s and 1s.
Either a matrix of shape (num_features, num_neurons) or a [FeaturePytree](../pytrees) of 0s and 1s, with
`feature_mask[feature_name]` of shape (num_neurons, ).
The mask will be used to select which features are used as predictors for which neuron.
Attributes
@@ -1023,6 +1024,49 @@ class PopulationGLM(GLM):
TypeError
- If provided `regularizer` or `observation_model` are not valid.
- If provided `feature_mask` is not an array-like of dimension two.
Examples
--------
>>> # Example with an array mask
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from nemos.glm import PopulationGLM
>>> # Define predictors (X), weights, and neural activity (y)
>>> num_samples, num_features, num_neurons = 100, 3, 2
>>> X = np.random.normal(size=(num_samples, num_features))
>>> weights = np.array([[ 0.5, 0. ], [-0.5, -0.5], [ 0. , 1. ]])
>>> y = np.random.poisson(np.exp(X.dot(weights)))
>>> # Define a feature mask, shape (num_features, num_neurons)
>>> feature_mask = jnp.array([[1, 0], [1, 1], [0, 1]])
>>> print("Feature mask:")
>>> print(feature_mask)
>>> # Create and fit the model
>>> model = PopulationGLM(feature_mask=feature_mask)
>>> model.fit(X, y)
>>> # Check the fitted coefficients and intercepts
>>> print("Model coefficients:")
>>> print(model.coef_)
>>> # Example with a FeaturePytree mask
>>> from nemos.pytrees import FeaturePytree
>>> # Define two features
>>> feature_1 = np.random.normal(size=(num_samples, 2))
>>> feature_2 = np.random.normal(size=(num_samples, 1))
>>> # Define the FeaturePytree predictor, and weights
>>> X = FeaturePytree(feature_1=feature_1, feature_2=feature_2)
>>> weights = dict(feature_1=jnp.array([[0., 0.5], [0., -0.5]]), feature_2=jnp.array([[1., 0.]]))
>>> # Compute the firing rate and counts
>>> rate = np.exp(X["feature_1"].dot(weights["feature_1"]) + X["feature_2"].dot(weights["feature_2"]))
>>> y = np.random.poisson(rate)
>>> # Define a feature mask with arrays of shape (num_neurons, )
>>> feature_mask = FeaturePytree(feature_1=jnp.array([0, 1]), feature_2=jnp.array([1, 0]))
>>> print("Feature mask:")
>>> print(feature_mask)
>>> # Fit a PopulationGLM
>>> model = PopulationGLM(feature_mask=feature_mask)
>>> model.fit(X, y)
>>> print("Model coefficients:")
>>> print(model.coef_)
"""

def __init__(
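As a hedged follow-up to the array-mask example in the `PopulationGLM` docstring above (an editor's sketch, not part of the diff): since `feature_mask` selects which features act as predictors for each neuron, the coefficients of masked-out features should come out as (effectively) zero after fitting.

import numpy as np

# continuing the array-mask example: model and feature_mask as defined above;
# coef_ and feature_mask both have shape (num_features, num_neurons)
masked_out_coef = np.asarray(model.coef_)[np.asarray(feature_mask) == 0]
print(np.allclose(masked_out_coef, 0.0))   # expected: True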
26 changes: 25 additions & 1 deletion src/nemos/regularizer.py
@@ -523,10 +523,34 @@ class GroupLasso(ProxGradientRegularizer):
Attributes
----------
mask : Union[jnp.ndarray, NDArray]
A 2d mask array indicating groups of features for regularization.
A 2d mask array indicating groups of features for regularization, shape (num_groups, num_features).
Each row represents a group of features.
Each column corresponds to a feature, where a value of 1 indicates that the feature belongs
to the group, and a value of 0 indicates it doesn't.
Examples
--------
>>> import numpy as np
>>> from nemos.regularizer import GroupLasso
>>> from nemos.glm import GLM
>>> # simulate some counts
>>> num_samples, num_features, num_groups = 1000, 5, 3
>>> X = np.random.normal(size=(num_samples, num_features)) # design matrix
>>> w = [0, 0.5, 1, 0, -0.5] # define some weights
>>> y = np.random.poisson(np.exp(X.dot(w))) # observed counts
>>> # Define a mask for 3 groups and 5 features
>>> mask = np.zeros((num_groups, num_features))
>>> mask[0] = [1, 0, 0, 1, 0] # Group 0 includes features 0 and 3
>>> mask[1] = [0, 1, 0, 0, 0] # Group 1 includes feature 1
>>> mask[2] = [0, 0, 1, 0, 1] # Group 2 includes features 2 and 4
>>> # Create the GroupLasso regularizer instance
>>> group_lasso = GroupLasso(solver_name='ProximalGradient', regularizer_strength=0.1, mask=mask)
>>> # fit a group-lasso glm
>>> model = GLM(regularizer=group_lasso).fit(X, y)
>>> print(f"coeff: {model.coef_}")
"""

def __init__(
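To close, a small hedged sketch (not from the commit) for sanity-checking a `GroupLasso` mask against the layout described in the docstring: rows are groups, columns are features, entries are 0 or 1. The requirement that every feature belongs to exactly one group is an assumption here, not something the hunk above states explicitly.

import numpy as np

def check_group_mask(mask):
    # expects shape (num_groups, num_features) with entries in {0, 1}
    mask = np.asarray(mask)
    if mask.ndim != 2:
        raise ValueError("mask must be 2d with shape (num_groups, num_features)")
    if not np.isin(mask, (0, 1)).all():
        raise ValueError("mask entries must be 0 or 1")
    # assumption: each feature is assigned to exactly one group
    if (mask.sum(axis=0) != 1).any():
        raise ValueError("each feature must belong to exactly one group")

check_group_mask(np.array([[1, 0, 0, 1, 0],
                           [0, 1, 0, 0, 0],
                           [0, 0, 1, 0, 1]]))   # the mask from the example above; passes silently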
