diff --git a/src/nemos/glm.py b/src/nemos/glm.py index 03684859..74cbdee5 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -104,6 +104,25 @@ class GLM(BaseRegressor): TypeError If provided `regularizer` or `observation_model` are not valid. + Examples + -------- + >>> import nemos as nmo + + >>> # define single neuron GLM model + >>> model = nmo.glm.GLM() + >>> print("Regularizer type: ", type(model.regularizer)) + Regularizer type: + >>> print("Observation model: ", type(model.observation_model)) + Observation model: + + + >>> # define GLM model of PoissonObservations model with soft-plus NL + >>> observation_models = nmo.observation_models.PoissonObservations(jax.nn.softplus) + >>> model = nmo.glm.GLM(observation_model=observation_models, solver_name="LBFGS") + >>> print("Regularizer type: ", type(model.regularizer)) + Regularizer type: + >>> print("Observation model: ", type(model.observation_model)) + Observation model: """ def __init__( @@ -307,6 +326,21 @@ def predict(self, X: DESIGN_INPUT_TYPE) -> jnp.ndarray: - If `X` is not three-dimensional. - If there's an inconsistent number of features between spike basis coefficients and `X`. + Examples + -------- + >>> # example input + >>> import numpy as np + >>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10) + + >>> # define and fit a GLM + >>> import nemos as nmo + >>> model = nmo.glm.GLM() + >>> model = model.fit(X, y) + + >>> # predict new spike data + >>> Xnew = np.random.normal(size=(20, X.shape[1])) + >>> predicted_spikes = model.predict(Xnew) + See Also -------- - [score](./#nemos.glm.GLM.score) @@ -407,6 +441,22 @@ def score( If X structure doesn't match the params, and if X and y have different number of samples. + Examples + -------- + >>> # example input + >>> import numpy as np + >>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10) + + >>> import nemos as nmo + >>> model = nmo.glm.GLM() + >>> model = model.fit(X, y) + + >>> # get model score + >>> log_likelihood_score = model.score(X, y) + + >>> # get a pseudo-R2 score + >>> pseudo_r2_score = model.score(X, y, score_type='pseudo-r2-McFadden') + Notes ----- The log-likelihood is not on a standard scale, its value is influenced by many factors, @@ -605,6 +655,21 @@ def fit( - If `init_params` are not array-like - If `init_params[i]` cannot be converted to jnp.ndarray for all i + Examples + ------- + >>> # example input + >>> import numpy as np + >>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10) + + >>> # fit a ridge regression Poisson GLM + >>> import nemos as nmo + >>> model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1) + >>> model = model.fit(X, y) + + >>> # get model weights and intercept + >>> model_weights = model.coef_ + >>> model_intercept = model.intercept_ + """ # validate the inputs & initialize solver init_params = self.initialize_params(X, y, init_params=init_params) @@ -686,7 +751,7 @@ def simulate( Returns ------- simulated_activity : - Simulated activity (spike counts for PoissonGLMs) for the neuron over time. + Simulated activity (spike counts for Poisson GLMs) for the neuron over time. Shape: (n_time_bins, ). firing_rates : Simulated rates for the neuron over time. Shape, (n_time_bins, ). @@ -698,6 +763,21 @@ def simulate( ValueError - If the instance has not been previously fitted. + Examples + -------- + >>> # example input + >>> import numpy as np + >>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10) + + >>> # define and fit model + >>> import nemos as nmo + >>> model = nmo.glm.GLM() + >>> model = model.fit(X, y) + + >>> # generate spikes and rates + >>> random_key = jax.random.key(123) + >>> Xnew = np.random.normal(size=(20, X.shape[1])) + >>> spikes, rates = model.simulate(random_key, Xnew) See Also -------- @@ -871,6 +951,16 @@ def initialize_state( ------- NamedTuple The initialized solver state + + Examples + -------- + >>> import numpy as np + >>> import nemos as nmo + >>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10) + >>> model = nmo.glm.GLM() + >>> params = model.initialize_params(X, y) + >>> opt_state = model.initialize_state(X, y, params) + >>> # Now ready to run optimization or update steps """ if isinstance(X, FeaturePytree): data = X.data @@ -1376,6 +1466,28 @@ def fit( - If the mask is a `FeaturePytree`, then `"feature_name"` is a predictor of neuron `j` if `feature_mask["feature_name"][j] == 1`. + Examples + -------- + >>> # Generate sample data + >>> import jax.numpy as jnp + >>> import numpy as np + >>> from nemos.glm import PopulationGLM + + >>> # Define predictors (X), weights, and neural activity (y) + >>> num_samples, num_features, num_neurons = 100, 3, 2 + >>> X = np.random.normal(size=(num_samples, num_features)) + >>> # Weights is defined by how each feature influences the output, shape (num_features, num_neurons) + >>> weights = np.array([[ 0.5, 0. ], [-0.5, -0.5], [ 0. , 1. ]]) + >>> # Output y simulates a Poisson distribution based on a linear model between features X and wegihts + >>> y = np.random.poisson(np.exp(X.dot(weights))) + + >>> # Define a feature mask, shape (num_features, num_neurons) + >>> feature_mask = jnp.array([[1, 0], [1, 1], [0, 1]]) + + >>> # Create and fit the model + >>> model = PopulationGLM(feature_mask=feature_mask).fit(X, y) + >>> print(model.coef_.shape) + (3, 2) """ return super().fit(X, y, init_params)