From ff6d456c188c3bf760fb61423347802af0b8b951 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 8 Jul 2024 17:31:58 -0400 Subject: [PATCH 1/6] improved docstrings --- src/nemos/glm.py | 12 ++++++------ src/nemos/regularizer.py | 26 +++++++++++++++++++++++++- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/src/nemos/glm.py b/src/nemos/glm.py index 9d0b7857..70a9edfc 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -561,8 +561,8 @@ def _initialize_parameters( (either as a FeaturePytree or ndarray, matching the structure of X) with shapes (n_features,). - The second element is the initialized intercept (bias terms) as an ndarray of shape (1,). - Example - ------- + Examples + -------- >>> import nemos as nmo >>> import numpy as np >>> X = np.zeros((100, 5)) # Example input @@ -868,8 +868,8 @@ def initialize_solver( - If `params` are not array-like when provided. - If `init_params[i]` cannot be converted to jnp.ndarray for all i - Example - ------- + Examples + -------- >>> X, y = load_data() # Hypothetical function to load data >>> params, opt_state = model.initialize_solver(X, y) >>> # Now ready to run optimization or update steps @@ -951,8 +951,8 @@ def update( indicating an invalid update step, typically due to numerical instabilities or inappropriate solver configurations. - Example - ------- + Examples + -------- >>> # Assume glm_instance is an instance of GLM that has been previously fitted. >>> params = glm_instance.coef_, glm_instance.intercept_ >>> opt_state = glm_instance.solver_state diff --git a/src/nemos/regularizer.py b/src/nemos/regularizer.py index e9b8b1b6..8863be5b 100644 --- a/src/nemos/regularizer.py +++ b/src/nemos/regularizer.py @@ -523,10 +523,34 @@ class GroupLasso(ProxGradientRegularizer): Attributes ---------- mask : Union[jnp.ndarray, NDArray] - A 2d mask array indicating groups of features for regularization. + A 2d mask array indicating groups of features for regularization, shape (num_groups, num_features). Each row represents a group of features. Each column corresponds to a feature, where a value of 1 indicates that the feature belongs to the group, and a value of 0 indicates it doesn't. + + Examples + -------- + >>> import numpy as np + >>> from nemos.regularizer import GroupLasso # Assuming the module is named group_lasso + >>> from nemos.glm import GLM + + >>> # simulate some counts + >>> num_samples, num_features, num_groups = 1000, 5, 3 + >>> X = np.random.normal(size=(num_samples, num_features)) # design matrix + >>> w = [0, 0.5, 1, 0, -0.5] # define some weights + >>> y = np.random.poisson(np.exp(X.dot(w))) # observed counts + + >>> # Define a mask for 3 groups and 5 features + >>> mask = np.zeros((num_groups, num_features)) + >>> mask[0] = [1, 0, 0, 1, 0] # # Group 0 includes features 0 and 3 + >>> mask[1] = [0, 1, 0, 0, 0] # # Group 1 includes features 1 + >>> mask[2] = [0, 0, 1, 0, 1] # # Group 2 includes features 2 and 4 + + >>> # Create the GroupLasso regularizer instance + >>> group_lasso = GroupLasso(solver_name='ProximalGradient', regularizer_strength=0.1, mask=mask) + >>> # fit a group-lasso glm + >>> model = GLM(regularizer=group_lasso).fit(X, y) + >>> print(f"coeff: {model.coef_}") """ def __init__( From 268566864d3641f5425d93445c4d142a168c4063 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 9 Jul 2024 09:27:35 -0400 Subject: [PATCH 2/6] added an example --- src/nemos/glm.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/nemos/glm.py b/src/nemos/glm.py index 70a9edfc..d8a20188 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -1023,6 +1023,31 @@ class PopulationGLM(GLM): TypeError - If provided `regularizer` or `observation_model` are not valid. - If provided `feature_mask` is not an array-like of dimension two. + + Examples + -------- + >>> import jax.numpy as jnp + >>> import numpy as np + >>> from nemos.glm import PopulationGLM + + >>> # Define predictors (X), weights, and neural activity (y) + >>> num_samples, num_features, num_neurons = 100, 3, 2 + >>> X = np.random.normal(size=(num_samples, num_features)) + >>> weights = np.array([[ 0.5, 0. ], [-0.5, -0.5], [ 0. , 1. ]]) + >>> y = np.random.poisson(np.exp(X.dot(weights))) + + >>> # Define a feature mask, shape (num_features, num_neurons) + >>> feature_mask = jnp.array([[1, 0], [1, 1], [0, 1]]) + >>> print("Feature mask:") + >>> print(feature_mask) + + >>> # Create and fit the model + >>> model = PopulationGLM(feature_mask=feature_mask) + >>> model.fit(X, y) + + >>> # Check the fitted coefficients and intercepts + >>> print("Model coefficients:") + >>> print(model.coef_) """ def __init__( From 5acfdbf48a21433183db8f5df018f695b3d7014c Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 9 Jul 2024 09:30:46 -0400 Subject: [PATCH 3/6] removed double comments --- src/nemos/regularizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/nemos/regularizer.py b/src/nemos/regularizer.py index 8863be5b..04820c9d 100644 --- a/src/nemos/regularizer.py +++ b/src/nemos/regularizer.py @@ -542,9 +542,9 @@ class GroupLasso(ProxGradientRegularizer): >>> # Define a mask for 3 groups and 5 features >>> mask = np.zeros((num_groups, num_features)) - >>> mask[0] = [1, 0, 0, 1, 0] # # Group 0 includes features 0 and 3 - >>> mask[1] = [0, 1, 0, 0, 0] # # Group 1 includes features 1 - >>> mask[2] = [0, 0, 1, 0, 1] # # Group 2 includes features 2 and 4 + >>> mask[0] = [1, 0, 0, 1, 0] # Group 0 includes features 0 and 3 + >>> mask[1] = [0, 1, 0, 0, 0] # Group 1 includes features 1 + >>> mask[2] = [0, 0, 1, 0, 1] # Group 2 includes features 2 and 4 >>> # Create the GroupLasso regularizer instance >>> group_lasso = GroupLasso(solver_name='ProximalGradient', regularizer_strength=0.1, mask=mask) From 6249ced025286a2ad824d4abd51443712186824a Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 9 Jul 2024 09:54:10 -0400 Subject: [PATCH 4/6] added a featurepytree example --- src/nemos/glm.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/src/nemos/glm.py b/src/nemos/glm.py index d8a20188..2ed3ab40 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -1026,28 +1026,44 @@ class PopulationGLM(GLM): Examples -------- + >>> # Example with an array mask >>> import jax.numpy as jnp >>> import numpy as np >>> from nemos.glm import PopulationGLM - >>> # Define predictors (X), weights, and neural activity (y) >>> num_samples, num_features, num_neurons = 100, 3, 2 >>> X = np.random.normal(size=(num_samples, num_features)) >>> weights = np.array([[ 0.5, 0. ], [-0.5, -0.5], [ 0. , 1. ]]) >>> y = np.random.poisson(np.exp(X.dot(weights))) - >>> # Define a feature mask, shape (num_features, num_neurons) >>> feature_mask = jnp.array([[1, 0], [1, 1], [0, 1]]) >>> print("Feature mask:") >>> print(feature_mask) - >>> # Create and fit the model >>> model = PopulationGLM(feature_mask=feature_mask) >>> model.fit(X, y) - >>> # Check the fitted coefficients and intercepts >>> print("Model coefficients:") >>> print(model.coef_) + + >>> # Example with a FeaturePytree mask + >>> from nemos.pytrees import FeaturePytree + >>> from nemos.tree_utils import pytree_map_and_reduce + >>> # Define two features + >>> feature_1 = np.random.normal(size=(num_samples, 2)) + >>> feature_2 = np.random.normal(size=(num_samples, 1)) + >>> # Define the FeaturePytree predictor, and weights + >>> X = FeaturePytree(feature_1=feature_1, feature_2=feature_2) + >>> weights = dict(feature_1=jnp.array([[0., 0.5], [0., -0.5]]), feature_2=jnp.array([[1., 0.]])) + >>> # Compute the firing rate and counts + >>> rate = np.exp(X["feature_1"].dot(weights["feature_1"]) + X["feature_2"].dot(weights["feature_2"])) + >>> y = np.random.poisson(rate) + >>> # define a feature mask such that + >>> # feature_1 is a predictor for the 2nd neuron and feature_2 for the 1st + >>> feature_mask = FeaturePytree(feature_1=jnp.array([0, 1]), feature_2=jnp.array([1, 0])) + >>> model = PopulationGLM(feature_mask=feature_mask) + >>> model.fit(X, y) + >>> print(model.coef_) """ def __init__( From ab1695b63b2a7944c96c65d98501ded05bdea152 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 9 Jul 2024 09:55:55 -0400 Subject: [PATCH 5/6] improved prints --- src/nemos/glm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/nemos/glm.py b/src/nemos/glm.py index 2ed3ab40..d08de24f 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -1048,7 +1048,6 @@ class PopulationGLM(GLM): >>> # Example with a FeaturePytree mask >>> from nemos.pytrees import FeaturePytree - >>> from nemos.tree_utils import pytree_map_and_reduce >>> # Define two features >>> feature_1 = np.random.normal(size=(num_samples, 2)) >>> feature_2 = np.random.normal(size=(num_samples, 1)) @@ -1061,8 +1060,11 @@ class PopulationGLM(GLM): >>> # define a feature mask such that >>> # feature_1 is a predictor for the 2nd neuron and feature_2 for the 1st >>> feature_mask = FeaturePytree(feature_1=jnp.array([0, 1]), feature_2=jnp.array([1, 0])) + >>> print("Feature mask:") + >>> print(feature_mask) >>> model = PopulationGLM(feature_mask=feature_mask) >>> model.fit(X, y) + >>> print("Model coefficients:") >>> print(model.coef_) """ From 0f7196ffb9e7bf14424475c8dd3aa83ce9b9b17a Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 9 Jul 2024 09:59:48 -0400 Subject: [PATCH 6/6] improved text --- src/nemos/glm.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/nemos/glm.py b/src/nemos/glm.py index d08de24f..4a74f208 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -1005,7 +1005,8 @@ class PopulationGLM(GLM): and related parameters. Default is UnRegularized regression with gradient descent. feature_mask : - Either a matrix of shape (num_features, num_neurons) or a [FeaturePytree](../pytrees) of 0s and 1s. + Either a matrix of shape (num_features, num_neurons) or a [FeaturePytree](../pytrees) of 0s and 1s, with + `feature_mask[feature_name]` of shape (num_neurons, ). The mask will be used to select which features are used as predictors for which neuron. Attributes @@ -1057,11 +1058,11 @@ class PopulationGLM(GLM): >>> # Compute the firing rate and counts >>> rate = np.exp(X["feature_1"].dot(weights["feature_1"]) + X["feature_2"].dot(weights["feature_2"])) >>> y = np.random.poisson(rate) - >>> # define a feature mask such that - >>> # feature_1 is a predictor for the 2nd neuron and feature_2 for the 1st + >>> # Define a feature mask with arrays of shape (num_neurons, ) >>> feature_mask = FeaturePytree(feature_1=jnp.array([0, 1]), feature_2=jnp.array([1, 0])) >>> print("Feature mask:") >>> print(feature_mask) + >>> # Fit a PopulationGLM >>> model = PopulationGLM(feature_mask=feature_mask) >>> model.fit(X, y) >>> print("Model coefficients:")