From d9c270978ed9b6c9e45b825f3b63a704738ce4f4 Mon Sep 17 00:00:00 2001 From: Pranati Modumudi Date: Tue, 1 Oct 2024 22:52:49 -0400 Subject: [PATCH 01/18] first 2 functions docs --- src/nemos/glm.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/nemos/glm.py b/src/nemos/glm.py index 5c65ef81..539e0bbe 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -104,6 +104,18 @@ class GLM(BaseRegressor): TypeError If provided `regularizer` or `observation_model` are not valid. + Examples + -------- + >>> from nemos.glm import GLM + # define simple GLM model + >>> model = nmo.glm.GLM() + >>> print("Regularization type: ", type(model.regularizer)) + >>> print("Observation model: ", type(model.observation_model)) + + # define GLM model of PoissonObservations model with soft-plus NL + >>> observation_models = nmo.observation_models.PoissonObservations(jax.nn.softplus) + >>> model = nmo.glm.GLM(observation_model=observation_models, \ + ... solver_name="LBFGS") """ def __init__( From 1e7b03cdc4464d25c0feefc40958403162349d2f Mon Sep 17 00:00:00 2001 From: Pranati Modumudi Date: Tue, 1 Oct 2024 22:53:18 -0400 Subject: [PATCH 02/18] first 2 functions docs --- src/nemos/glm.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/src/nemos/glm.py b/src/nemos/glm.py index 539e0bbe..06a5bcf5 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -107,6 +107,7 @@ class GLM(BaseRegressor): Examples -------- >>> from nemos.glm import GLM + # define simple GLM model >>> model = nmo.glm.GLM() >>> print("Regularization type: ", type(model.regularizer)) @@ -319,6 +320,11 @@ def predict(self, X: DESIGN_INPUT_TYPE) -> jnp.ndarray: - If `X` is not three-dimensional. - If there's an inconsistent number of features between spike basis coefficients and `X`. + Examples + -------- + >>> model = nmo.glm.GLM() + >>> model.fit(X, y) + See Also -------- - [score](./#nemos.glm.GLM.score) @@ -419,6 +425,11 @@ def score( If X structure doesn't match the params, and if X and y have different number of samples. + Examples + -------- + + >>> + Notes ----- The log-likelihood is not on a standard scale, its value is influenced by many factors, @@ -617,6 +628,33 @@ def fit( - If `init_params` are not array-like - If `init_params[i]` cannot be converted to jnp.ndarray for all i + Examples + ------- + + # fit a ridge regression Poisson GLM + >>> import nemos as nmo + # random design tensor. Shape (n_time_points, n_features). + >>> X = 0.5*np.random.normal(size=(100, 5)) + + # set log-rates & weights, shape (1, ) and (n_features, ) respectively. + >>> b_true = np.zeros((1, )) + >>> w_true = np.random.normal(size=(5, )) + + # sparsify weights + >>> w_true[1:4] = 0. + + # generate counts + >>> rate = jax.numpy.exp(jax.numpy.einsum("k,tk->t", w_true, X) + b_true) + >>> spikes = np.random.poisson(rate) + + # define and fit model + >>> model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1) + >>> model.fit(X, y) + + >>> print("Ridge results") + >>> print("True weights: ", w_true) + >>> print("Recovered weights: ", model.coef_) + """ # validate the inputs & initialize solver init_params = self.initialize_params(X, y, init_params=init_params) From 29b1c4fbdeaf4a9880b9af3a166b19ce55d1bf21 Mon Sep 17 00:00:00 2001 From: Pranati Modumudi Date: Fri, 4 Oct 2024 09:37:02 -0400 Subject: [PATCH 03/18] update glm examples --- src/nemos/glm.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/src/nemos/glm.py b/src/nemos/glm.py index 06a5bcf5..31836690 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -427,8 +427,10 @@ def score( Examples -------- - - >>> + >>> model = nmo.glm.GLM() + >>> model.fit(X, y) + >>> print(f"GLM log-likelihood: {model.score(X, y)}") + >>> print(f"GLM pseudo-r2-McFadden: {model.score(X, y, score_type='pseudo-r2-McFadden')}") Notes ----- @@ -633,26 +635,10 @@ def fit( # fit a ridge regression Poisson GLM >>> import nemos as nmo - # random design tensor. Shape (n_time_points, n_features). - >>> X = 0.5*np.random.normal(size=(100, 5)) - - # set log-rates & weights, shape (1, ) and (n_features, ) respectively. - >>> b_true = np.zeros((1, )) - >>> w_true = np.random.normal(size=(5, )) - - # sparsify weights - >>> w_true[1:4] = 0. - - # generate counts - >>> rate = jax.numpy.exp(jax.numpy.einsum("k,tk->t", w_true, X) + b_true) - >>> spikes = np.random.poisson(rate) - - # define and fit model >>> model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1) >>> model.fit(X, y) >>> print("Ridge results") - >>> print("True weights: ", w_true) >>> print("Recovered weights: ", model.coef_) """ @@ -761,6 +747,11 @@ def simulate( ValueError - If the instance has not been previously fitted. + Examples + -------- + # generate spikes and rates given X + >>> spikes, rates = model.simulate(random_key, X) + See Also -------- From 05d9f42e3e11e744cf03fc04759335b783cde999 Mon Sep 17 00:00:00 2001 From: Pranati Modumudi Date: Fri, 4 Oct 2024 14:51:44 -0400 Subject: [PATCH 04/18] misc --- src/nemos/glm.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/nemos/glm.py b/src/nemos/glm.py index 31836690..8beb9977 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -322,8 +322,12 @@ def predict(self, X: DESIGN_INPUT_TYPE) -> jnp.ndarray: Examples -------- + # define and fit a GLM >>> model = nmo.glm.GLM() - >>> model.fit(X, y) + >>> model.fit(X_train, y) + + # predict spike data + >>> model.predict(X_test) See Also -------- @@ -429,6 +433,8 @@ def score( -------- >>> model = nmo.glm.GLM() >>> model.fit(X, y) + + # get model score >>> print(f"GLM log-likelihood: {model.score(X, y)}") >>> print(f"GLM pseudo-r2-McFadden: {model.score(X, y, score_type='pseudo-r2-McFadden')}") @@ -750,9 +756,9 @@ def simulate( Examples -------- # generate spikes and rates given X + >>> random_key = jax.random.key(123) >>> spikes, rates = model.simulate(random_key, X) - See Also -------- [predict](./#nemos.glm.GLM.predict) : From 545fd3ac07953180abd1c77e0e755458bead10fa Mon Sep 17 00:00:00 2001 From: Pranati Modumudi Date: Sun, 13 Oct 2024 12:49:27 -0400 Subject: [PATCH 05/18] glm.py examples added --- src/nemos/glm.py | 65 +++++++++++++++++++++++++++++++++++------------- 1 file changed, 48 insertions(+), 17 deletions(-) diff --git a/src/nemos/glm.py b/src/nemos/glm.py index 2149db66..0b4a9f0d 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -106,16 +106,23 @@ class GLM(BaseRegressor): Examples -------- - >>> from nemos.glm import GLM + >>> import nemos as nmo + # define simple GLM model >>> model = nmo.glm.GLM() - >>> print("Regularization type: ", type(model.regularizer)) + >>> print("Regularizer type: ", type(model.regularizer)) + Regularizer type: >>> print("Observation model: ", type(model.observation_model)) + Observation model: + # define GLM model of PoissonObservations model with soft-plus NL >>> observation_models = nmo.observation_models.PoissonObservations(jax.nn.softplus) - >>> model = nmo.glm.GLM(observation_model=observation_models, \ - ... solver_name="LBFGS") + >>> model = nmo.glm.GLM(observation_model=observation_models, solver_name="LBFGS") + >>> print("Regularizer type: ", type(model.regularizer)) + Regularizer type: + >>> print("Observation model: ", type(model.observation_model)) + Observation model: """ def __init__( @@ -321,12 +328,18 @@ def predict(self, X: DESIGN_INPUT_TYPE) -> jnp.ndarray: Examples -------- - # define and fit a GLM + # example input + >>> import numpy as np + >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10) + >>> Xnew = np.random.normal(size=(20, ) + X.shape[1:]) + + # define and fit a GLM + >>> import nemos as nmo >>> model = nmo.glm.GLM() - >>> model.fit(X_train, y) + >>> model = model.fit(X, y) - # predict spike data - >>> model.predict(X_test) + # predict new spike data + >>> predicted_spikes = model.predict(Xnew) See Also -------- @@ -430,12 +443,17 @@ def score( Examples -------- + # example input + >>> import numpy as np + >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10) + + >>> import nemos as nmo >>> model = nmo.glm.GLM() - >>> model.fit(X, y) + >>> model = model.fit(X, y) # get model score - >>> print(f"GLM log-likelihood: {model.score(X, y)}") - >>> print(f"GLM pseudo-r2-McFadden: {model.score(X, y, score_type='pseudo-r2-McFadden')}") + >>> log_likelihood_score = model.score(X, y) + >>> pseudo_r2_score = model.score(X, y, score_type='pseudo-r2-McFadden') Notes ----- @@ -637,14 +655,17 @@ def fit( Examples ------- + # example input + >>> import numpy as np + >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10) # fit a ridge regression Poisson GLM >>> import nemos as nmo >>> model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1) - >>> model.fit(X, y) - - >>> print("Ridge results") - >>> print("Recovered weights: ", model.coef_) + >>> model = model.fit(X, y) + + # get model weights + >>> model_weights = model.coef_ """ # validate the inputs & initialize solver @@ -741,9 +762,19 @@ def simulate( Examples -------- - # generate spikes and rates given X + # example input + >>> import numpy as np + >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10) + >>> Xnew = np.random.normal(size=(20, ) + X.shape[1:]) + + # define and fit model + >>> import nemos as nmo + >>> model = nmo.glm.GLM() + >>> model = model.fit(X, y) + + # generate spikes and rates >>> random_key = jax.random.key(123) - >>> spikes, rates = model.simulate(random_key, X) + >>> spikes, rates = model.simulate(random_key, Xnew) See Also -------- From 8dcee76c3a3b564ed269704759807443b7bbd68c Mon Sep 17 00:00:00 2001 From: pranmod01 Date: Sun, 13 Oct 2024 13:14:19 -0400 Subject: [PATCH 06/18] glm examples + format checks --- docs/how_to_guide/plot_02_glm_demo.py | 2 +- docs/how_to_guide/plot_03_glm_pytree.py | 4 +- docs/how_to_guide/plot_04_population_glm.py | 2 +- docs/tutorials/plot_01_current_injection.py | 4 +- docs/tutorials/plot_06_calcium_imaging.py | 4 +- src/nemos/base_regressor.py | 5 +- src/nemos/convolve.py | 3 +- src/nemos/glm.py | 95 +++++++++++---------- src/nemos/observation_models.py | 3 +- src/nemos/regularizer.py | 5 +- src/nemos/simulation.py | 5 +- src/nemos/solvers.py | 7 +- src/nemos/type_casting.py | 5 +- src/nemos/typing.py | 3 +- src/nemos/utils.py | 5 +- src/nemos/validation.py | 3 +- 16 files changed, 83 insertions(+), 72 deletions(-) diff --git a/docs/how_to_guide/plot_02_glm_demo.py b/docs/how_to_guide/plot_02_glm_demo.py index f1c6e3b2..23ef1396 100644 --- a/docs/how_to_guide/plot_02_glm_demo.py +++ b/docs/how_to_guide/plot_02_glm_demo.py @@ -25,12 +25,12 @@ """ -import jax import matplotlib.pyplot as plt import numpy as np from matplotlib.patches import Rectangle from sklearn import model_selection +import jax import nemos as nmo np.random.seed(111) diff --git a/docs/how_to_guide/plot_03_glm_pytree.py b/docs/how_to_guide/plot_03_glm_pytree.py index 2d36db3b..22e0f5f2 100644 --- a/docs/how_to_guide/plot_03_glm_pytree.py +++ b/docs/how_to_guide/plot_03_glm_pytree.py @@ -12,10 +12,10 @@ First, however, let's briefly discuss FeaturePytrees. """ -import jax -import jax.numpy as jnp import numpy as np +import jax +import jax.numpy as jnp import nemos as nmo np.random.seed(111) diff --git a/docs/how_to_guide/plot_04_population_glm.py b/docs/how_to_guide/plot_04_population_glm.py index 70dac9cd..a4e0eec4 100644 --- a/docs/how_to_guide/plot_04_population_glm.py +++ b/docs/how_to_guide/plot_04_population_glm.py @@ -22,10 +22,10 @@ Let's generate some synthetic data and fit a population model. """ -import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np +import jax.numpy as jnp import nemos as nmo np.random.seed(123) diff --git a/docs/tutorials/plot_01_current_injection.py b/docs/tutorials/plot_01_current_injection.py index 8eb83dbe..aeafe377 100644 --- a/docs/tutorials/plot_01_current_injection.py +++ b/docs/tutorials/plot_01_current_injection.py @@ -48,12 +48,12 @@ -# Import everything -import jax import matplotlib.pyplot as plt import numpy as np import pynapple as nap +# Import everything +import jax import nemos as nmo # some helper plotting functions diff --git a/docs/tutorials/plot_06_calcium_imaging.py b/docs/tutorials/plot_06_calcium_imaging.py index 985a0b61..514a0367 100644 --- a/docs/tutorials/plot_06_calcium_imaging.py +++ b/docs/tutorials/plot_06_calcium_imaging.py @@ -10,12 +10,12 @@ """ -import jax -import jax.numpy as jnp import matplotlib.pyplot as plt import pynapple as nap from sklearn.linear_model import LinearRegression +import jax +import jax.numpy as jnp import nemos as nmo # %% diff --git a/src/nemos/base_regressor.py b/src/nemos/base_regressor.py index 5f651313..7d2cba7f 100644 --- a/src/nemos/base_regressor.py +++ b/src/nemos/base_regressor.py @@ -9,11 +9,12 @@ from copy import deepcopy from typing import Any, Dict, NamedTuple, Optional, Tuple, Union -import jax -import jax.numpy as jnp import jaxopt from numpy.typing import ArrayLike, NDArray +import jax +import jax.numpy as jnp + from . import solvers, utils, validation from ._regularizer_builder import AVAILABLE_REGULARIZERS, create_regularizer from .base_class import Base diff --git a/src/nemos/convolve.py b/src/nemos/convolve.py index 6a6294c3..e9a8d1fc 100644 --- a/src/nemos/convolve.py +++ b/src/nemos/convolve.py @@ -8,9 +8,10 @@ from functools import partial from typing import Any, Literal, Optional +from numpy.typing import ArrayLike, NDArray + import jax import jax.numpy as jnp -from numpy.typing import ArrayLike, NDArray from . import type_casting, utils diff --git a/src/nemos/glm.py b/src/nemos/glm.py index 0b4a9f0d..70b4ee9e 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -7,12 +7,13 @@ from functools import wraps from typing import Any, Callable, Literal, NamedTuple, Optional, Tuple, Union -import jax -import jax.numpy as jnp import jaxopt from numpy.typing import ArrayLike from scipy.optimize import root +import jax +import jax.numpy as jnp + from . import observation_models as obs from . import tree_utils, validation from .base_regressor import BaseRegressor @@ -663,7 +664,7 @@ def fit( >>> import nemos as nmo >>> model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1) >>> model = model.fit(X, y) - + # get model weights >>> model_weights = model.coef_ @@ -736,50 +737,50 @@ def simulate( ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Simulate neural activity in response to a feed-forward input. - Parameters - ---------- - random_key : - jax.random.key for seeding the simulation. - feedforward_input : - External input matrix to the model, representing factors like convolved currents, - light intensities, etc. When not provided, the simulation is done with coupling-only. - Array of shape (n_time_bins, n_basis_input) or pytree of same. - - Returns - ------- - simulated_activity : - Simulated activity (spike counts for PoissonGLMs) for the neuron over time. - Shape: (n_time_bins, ). - firing_rates : - Simulated rates for the neuron over time. Shape, (n_time_bins, ). - - Raises - ------ - NotFittedError - If the model hasn't been fitted prior to calling this method. - ValueError - - If the instance has not been previously fitted. - - Examples - -------- - # example input - >>> import numpy as np - >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10) - >>> Xnew = np.random.normal(size=(20, ) + X.shape[1:]) - - # define and fit model - >>> import nemos as nmo - >>> model = nmo.glm.GLM() - >>> model = model.fit(X, y) - - # generate spikes and rates - >>> random_key = jax.random.key(123) - >>> spikes, rates = model.simulate(random_key, Xnew) - - See Also - -------- - [predict](./#nemos.glm.GLM.predict) : - Method to predict rates based on the model's parameters. + Parameters + ---------- + random_key : + jax.random.key for seeding the simulation. + feedforward_input : + External input matrix to the model, representing factors like convolved currents, + light intensities, etc. When not provided, the simulation is done with coupling-only. + Array of shape (n_time_bins, n_basis_input) or pytree of same. + + Returns + ------- + simulated_activity : + Simulated activity (spike counts for PoissonGLMs) for the neuron over time. + Shape: (n_time_bins, ). + firing_rates : + Simulated rates for the neuron over time. Shape, (n_time_bins, ). + + Raises + ------ + NotFittedError + If the model hasn't been fitted prior to calling this method. + ValueError + - If the instance has not been previously fitted. + + Examples + -------- + # example input + >>> import numpy as np + >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10) + >>> Xnew = np.random.normal(size=(20, ) + X.shape[1:]) + + # define and fit model + >>> import nemos as nmo + >>> model = nmo.glm.GLM() + >>> model = model.fit(X, y) + + # generate spikes and rates + >>> random_key = jax.random.key(123) + >>> spikes, rates = model.simulate(random_key, Xnew) + + See Also + -------- + [predict](./#nemos.glm.GLM.predict) : + Method to predict rates based on the model's parameters. """ # check if the model is fit self._check_is_fit() diff --git a/src/nemos/observation_models.py b/src/nemos/observation_models.py index 9d683ae1..de0b8981 100644 --- a/src/nemos/observation_models.py +++ b/src/nemos/observation_models.py @@ -3,9 +3,10 @@ import abc from typing import Callable, Literal, Union +from numpy.typing import NDArray + import jax import jax.numpy as jnp -from numpy.typing import NDArray from . import utils from .base_class import Base diff --git a/src/nemos/regularizer.py b/src/nemos/regularizer.py index 651f65a5..15a2181f 100644 --- a/src/nemos/regularizer.py +++ b/src/nemos/regularizer.py @@ -9,11 +9,12 @@ import abc from typing import Callable, Tuple, Union -import jax -import jax.numpy as jnp import jaxopt from numpy.typing import NDArray +import jax +import jax.numpy as jnp + from . import tree_utils from .base_class import Base from .proximal_operator import prox_group_lasso diff --git a/src/nemos/simulation.py b/src/nemos/simulation.py index e698e702..92c64c12 100644 --- a/src/nemos/simulation.py +++ b/src/nemos/simulation.py @@ -2,12 +2,13 @@ from typing import Callable, Tuple, Union -import jax -import jax.numpy as jnp import numpy as np import scipy.stats as sts from numpy.typing import NDArray +import jax +import jax.numpy as jnp + from . import convolve, validation from .pytrees import FeaturePytree diff --git a/src/nemos/solvers.py b/src/nemos/solvers.py index 4c060609..779d24d0 100644 --- a/src/nemos/solvers.py +++ b/src/nemos/solvers.py @@ -1,13 +1,14 @@ from functools import partial from typing import Callable, NamedTuple, Optional, Union +from jaxopt import OptStep +from jaxopt._src import loop +from jaxopt.prox import prox_none + import jax import jax.flatten_util import jax.numpy as jnp from jax import grad, jit, lax, random -from jaxopt import OptStep -from jaxopt._src import loop -from jaxopt.prox import prox_none from .tree_utils import tree_add_scalar_mul, tree_l2_norm, tree_slice, tree_sub from .typing import KeyArrayLike, Pytree diff --git a/src/nemos/type_casting.py b/src/nemos/type_casting.py index 8a5522b8..02a73dd8 100644 --- a/src/nemos/type_casting.py +++ b/src/nemos/type_casting.py @@ -12,12 +12,13 @@ from functools import wraps from typing import Any, Callable, List, Literal, Optional, Type, Union -import jax -import jax.numpy as jnp import numpy as np import pynapple as nap from numpy.typing import NDArray +import jax +import jax.numpy as jnp + from . import tree_utils _NAP_TIME_PRECISION = 10 ** (-nap.nap_config.time_index_precision) diff --git a/src/nemos/typing.py b/src/nemos/typing.py index dd9bc5a6..fa86ca82 100644 --- a/src/nemos/typing.py +++ b/src/nemos/typing.py @@ -2,8 +2,9 @@ from typing import Any, Callable, NamedTuple, Tuple, Union -import jax.numpy as jnp import jaxopt + +import jax.numpy as jnp from jax._src.typing import ArrayLike from .pytrees import FeaturePytree diff --git a/src/nemos/utils.py b/src/nemos/utils.py index 87ad0472..f457e4f1 100644 --- a/src/nemos/utils.py +++ b/src/nemos/utils.py @@ -3,11 +3,12 @@ import warnings from typing import Any, Callable, List, Literal, Union -import jax -import jax.numpy as jnp import numpy as np from numpy.typing import NDArray +import jax +import jax.numpy as jnp + from .tree_utils import pytree_map_and_reduce from .type_casting import is_numpy_array_like, support_pynapple diff --git a/src/nemos/validation.py b/src/nemos/validation.py index 9dd0853f..c3ff63bf 100644 --- a/src/nemos/validation.py +++ b/src/nemos/validation.py @@ -3,9 +3,10 @@ import warnings from typing import Any, Optional, Union +from numpy.typing import DTypeLike, NDArray + import jax import jax.numpy as jnp -from numpy.typing import DTypeLike, NDArray from .pytrees import FeaturePytree from .tree_utils import get_valid_multitree, pytree_map_and_reduce From e56e47579d69b021be4b8ce3d163d4e0439a3af1 Mon Sep 17 00:00:00 2001 From: pranmod01 Date: Fri, 18 Oct 2024 08:43:10 -0700 Subject: [PATCH 07/18] random changes --- jax | 1 + src/nemos/basis.py | 13 +++++++++++++ 2 files changed, 14 insertions(+) create mode 160000 jax diff --git a/jax b/jax new file mode 160000 index 00000000..9cf952a5 --- /dev/null +++ b/jax @@ -0,0 +1 @@ +Subproject commit 9cf952a535518da59cdcecc9145dba287beddca2 diff --git a/src/nemos/basis.py b/src/nemos/basis.py index f5907480..ecb69645 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -204,6 +204,19 @@ def fit(self, X: FeatureMatrix, y=None): ------- self : The transformer object. + + Examples + -------- + # Example input + >>> import numpy as np + >>> X, y = np.random.normal(size=(100, 2)), np.random.uniform(size=100) + + # Define and fit tranformation basis + >>> from nemos.basis import MSplineBasis, TransformerBasis + >>> basis = MSplineBasis(10) + >>> transformer = TransformerBasis(basis) + >>> transformer = transformer.fit(X) + """ self._basis._set_kernel(*self._unpack_inputs(X)) return self From a9799a38b810c8ac7a280521f4c09910541e26fd Mon Sep 17 00:00:00 2001 From: pranmod01 Date: Fri, 18 Oct 2024 09:04:37 -0700 Subject: [PATCH 08/18] isort fix --- jax | 1 - src/nemos/__init__.py | 2 +- src/nemos/_documentation_utils/__init__.py | 2 +- src/nemos/base_regressor.py | 5 ++--- src/nemos/convolve.py | 3 +-- src/nemos/glm.py | 5 ++--- src/nemos/observation_models.py | 3 +-- src/nemos/regularizer.py | 5 ++--- src/nemos/simulation.py | 5 ++--- src/nemos/solvers.py | 7 +++---- src/nemos/type_casting.py | 5 ++--- src/nemos/typing.py | 3 +-- src/nemos/utils.py | 5 ++--- src/nemos/validation.py | 3 +-- 14 files changed, 21 insertions(+), 33 deletions(-) delete mode 160000 jax diff --git a/jax b/jax deleted file mode 160000 index 9cf952a5..00000000 --- a/jax +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 9cf952a535518da59cdcecc9145dba287beddca2 diff --git a/src/nemos/__init__.py b/src/nemos/__init__.py index 97c5b3db..aedd05c0 100644 --- a/src/nemos/__init__.py +++ b/src/nemos/__init__.py @@ -14,5 +14,5 @@ styles, tree_utils, type_casting, - utils, + utils ) diff --git a/src/nemos/_documentation_utils/__init__.py b/src/nemos/_documentation_utils/__init__.py index 3cd63e0e..1c64a43a 100644 --- a/src/nemos/_documentation_utils/__init__.py +++ b/src/nemos/_documentation_utils/__init__.py @@ -19,5 +19,5 @@ plot_rates_and_smoothed_counts, plot_weighted_sum_basis, run_animation, - tuning_curve_plot, + tuning_curve_plot ) diff --git a/src/nemos/base_regressor.py b/src/nemos/base_regressor.py index ba2782ed..e4a425ce 100644 --- a/src/nemos/base_regressor.py +++ b/src/nemos/base_regressor.py @@ -9,11 +9,10 @@ from copy import deepcopy from typing import Any, Dict, NamedTuple, Optional, Tuple, Union -import jaxopt -from numpy.typing import ArrayLike, NDArray - import jax import jax.numpy as jnp +import jaxopt +from numpy.typing import ArrayLike, NDArray from . import solvers, utils, validation from ._regularizer_builder import AVAILABLE_REGULARIZERS, create_regularizer diff --git a/src/nemos/convolve.py b/src/nemos/convolve.py index e9a8d1fc..6a6294c3 100644 --- a/src/nemos/convolve.py +++ b/src/nemos/convolve.py @@ -8,10 +8,9 @@ from functools import partial from typing import Any, Literal, Optional -from numpy.typing import ArrayLike, NDArray - import jax import jax.numpy as jnp +from numpy.typing import ArrayLike, NDArray from . import type_casting, utils diff --git a/src/nemos/glm.py b/src/nemos/glm.py index 70b4ee9e..44cbb4d4 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -7,13 +7,12 @@ from functools import wraps from typing import Any, Callable, Literal, NamedTuple, Optional, Tuple, Union +import jax +import jax.numpy as jnp import jaxopt from numpy.typing import ArrayLike from scipy.optimize import root -import jax -import jax.numpy as jnp - from . import observation_models as obs from . import tree_utils, validation from .base_regressor import BaseRegressor diff --git a/src/nemos/observation_models.py b/src/nemos/observation_models.py index de0b8981..9d683ae1 100644 --- a/src/nemos/observation_models.py +++ b/src/nemos/observation_models.py @@ -3,10 +3,9 @@ import abc from typing import Callable, Literal, Union -from numpy.typing import NDArray - import jax import jax.numpy as jnp +from numpy.typing import NDArray from . import utils from .base_class import Base diff --git a/src/nemos/regularizer.py b/src/nemos/regularizer.py index c2cc75d6..6d6cf0bd 100644 --- a/src/nemos/regularizer.py +++ b/src/nemos/regularizer.py @@ -9,11 +9,10 @@ import abc from typing import Callable, Tuple, Union -import jaxopt -from numpy.typing import NDArray - import jax import jax.numpy as jnp +import jaxopt +from numpy.typing import NDArray from . import tree_utils from .base_class import Base diff --git a/src/nemos/simulation.py b/src/nemos/simulation.py index 92c64c12..e698e702 100644 --- a/src/nemos/simulation.py +++ b/src/nemos/simulation.py @@ -2,13 +2,12 @@ from typing import Callable, Tuple, Union +import jax +import jax.numpy as jnp import numpy as np import scipy.stats as sts from numpy.typing import NDArray -import jax -import jax.numpy as jnp - from . import convolve, validation from .pytrees import FeaturePytree diff --git a/src/nemos/solvers.py b/src/nemos/solvers.py index 779d24d0..4c060609 100644 --- a/src/nemos/solvers.py +++ b/src/nemos/solvers.py @@ -1,14 +1,13 @@ from functools import partial from typing import Callable, NamedTuple, Optional, Union -from jaxopt import OptStep -from jaxopt._src import loop -from jaxopt.prox import prox_none - import jax import jax.flatten_util import jax.numpy as jnp from jax import grad, jit, lax, random +from jaxopt import OptStep +from jaxopt._src import loop +from jaxopt.prox import prox_none from .tree_utils import tree_add_scalar_mul, tree_l2_norm, tree_slice, tree_sub from .typing import KeyArrayLike, Pytree diff --git a/src/nemos/type_casting.py b/src/nemos/type_casting.py index 02a73dd8..8a5522b8 100644 --- a/src/nemos/type_casting.py +++ b/src/nemos/type_casting.py @@ -12,13 +12,12 @@ from functools import wraps from typing import Any, Callable, List, Literal, Optional, Type, Union +import jax +import jax.numpy as jnp import numpy as np import pynapple as nap from numpy.typing import NDArray -import jax -import jax.numpy as jnp - from . import tree_utils _NAP_TIME_PRECISION = 10 ** (-nap.nap_config.time_index_precision) diff --git a/src/nemos/typing.py b/src/nemos/typing.py index fa86ca82..dd9bc5a6 100644 --- a/src/nemos/typing.py +++ b/src/nemos/typing.py @@ -2,9 +2,8 @@ from typing import Any, Callable, NamedTuple, Tuple, Union -import jaxopt - import jax.numpy as jnp +import jaxopt from jax._src.typing import ArrayLike from .pytrees import FeaturePytree diff --git a/src/nemos/utils.py b/src/nemos/utils.py index f457e4f1..87ad0472 100644 --- a/src/nemos/utils.py +++ b/src/nemos/utils.py @@ -3,11 +3,10 @@ import warnings from typing import Any, Callable, List, Literal, Union -import numpy as np -from numpy.typing import NDArray - import jax import jax.numpy as jnp +import numpy as np +from numpy.typing import NDArray from .tree_utils import pytree_map_and_reduce from .type_casting import is_numpy_array_like, support_pynapple diff --git a/src/nemos/validation.py b/src/nemos/validation.py index c3ff63bf..9dd0853f 100644 --- a/src/nemos/validation.py +++ b/src/nemos/validation.py @@ -3,10 +3,9 @@ import warnings from typing import Any, Optional, Union -from numpy.typing import DTypeLike, NDArray - import jax import jax.numpy as jnp +from numpy.typing import DTypeLike, NDArray from .pytrees import FeaturePytree from .tree_utils import get_valid_multitree, pytree_map_and_reduce From 929da17c1917d28f86bd53a250d2b69a4cfe5544 Mon Sep 17 00:00:00 2001 From: pranmod01 Date: Fri, 18 Oct 2024 09:10:00 -0700 Subject: [PATCH 09/18] formatting fixed --- docs/how_to_guide/plot_02_glm_demo.py | 2 +- docs/how_to_guide/plot_03_glm_pytree.py | 4 ++-- docs/how_to_guide/plot_04_population_glm.py | 2 +- docs/tutorials/plot_01_current_injection.py | 4 ++-- docs/tutorials/plot_06_calcium_imaging.py | 4 ++-- src/nemos/__init__.py | 2 +- src/nemos/_documentation_utils/__init__.py | 2 +- 7 files changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/how_to_guide/plot_02_glm_demo.py b/docs/how_to_guide/plot_02_glm_demo.py index 23ef1396..f1c6e3b2 100644 --- a/docs/how_to_guide/plot_02_glm_demo.py +++ b/docs/how_to_guide/plot_02_glm_demo.py @@ -25,12 +25,12 @@ """ +import jax import matplotlib.pyplot as plt import numpy as np from matplotlib.patches import Rectangle from sklearn import model_selection -import jax import nemos as nmo np.random.seed(111) diff --git a/docs/how_to_guide/plot_03_glm_pytree.py b/docs/how_to_guide/plot_03_glm_pytree.py index 22e0f5f2..2d36db3b 100644 --- a/docs/how_to_guide/plot_03_glm_pytree.py +++ b/docs/how_to_guide/plot_03_glm_pytree.py @@ -12,10 +12,10 @@ First, however, let's briefly discuss FeaturePytrees. """ -import numpy as np - import jax import jax.numpy as jnp +import numpy as np + import nemos as nmo np.random.seed(111) diff --git a/docs/how_to_guide/plot_04_population_glm.py b/docs/how_to_guide/plot_04_population_glm.py index a4e0eec4..70dac9cd 100644 --- a/docs/how_to_guide/plot_04_population_glm.py +++ b/docs/how_to_guide/plot_04_population_glm.py @@ -22,10 +22,10 @@ Let's generate some synthetic data and fit a population model. """ +import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np -import jax.numpy as jnp import nemos as nmo np.random.seed(123) diff --git a/docs/tutorials/plot_01_current_injection.py b/docs/tutorials/plot_01_current_injection.py index aeafe377..8eb83dbe 100644 --- a/docs/tutorials/plot_01_current_injection.py +++ b/docs/tutorials/plot_01_current_injection.py @@ -48,12 +48,12 @@ +# Import everything +import jax import matplotlib.pyplot as plt import numpy as np import pynapple as nap -# Import everything -import jax import nemos as nmo # some helper plotting functions diff --git a/docs/tutorials/plot_06_calcium_imaging.py b/docs/tutorials/plot_06_calcium_imaging.py index 514a0367..985a0b61 100644 --- a/docs/tutorials/plot_06_calcium_imaging.py +++ b/docs/tutorials/plot_06_calcium_imaging.py @@ -10,12 +10,12 @@ """ +import jax +import jax.numpy as jnp import matplotlib.pyplot as plt import pynapple as nap from sklearn.linear_model import LinearRegression -import jax -import jax.numpy as jnp import nemos as nmo # %% diff --git a/src/nemos/__init__.py b/src/nemos/__init__.py index aedd05c0..97c5b3db 100644 --- a/src/nemos/__init__.py +++ b/src/nemos/__init__.py @@ -14,5 +14,5 @@ styles, tree_utils, type_casting, - utils + utils, ) diff --git a/src/nemos/_documentation_utils/__init__.py b/src/nemos/_documentation_utils/__init__.py index 1c64a43a..3cd63e0e 100644 --- a/src/nemos/_documentation_utils/__init__.py +++ b/src/nemos/_documentation_utils/__init__.py @@ -19,5 +19,5 @@ plot_rates_and_smoothed_counts, plot_weighted_sum_basis, run_animation, - tuning_curve_plot + tuning_curve_plot, ) From 5aba7d64b22f79b0361448f80f539d31537583d6 Mon Sep 17 00:00:00 2001 From: pranmod01 Date: Mon, 21 Oct 2024 22:35:02 -0700 Subject: [PATCH 10/18] fixed comments, added example for popglm fit, removed misc basis examples --- src/nemos/basis.py | 12 ----- src/nemos/glm.py | 128 ++++++++++++++++++++++++++------------------- 2 files changed, 74 insertions(+), 66 deletions(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 51127a31..f98c7abb 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -205,18 +205,6 @@ def fit(self, X: FeatureMatrix, y=None): self : The transformer object. - Examples - -------- - # Example input - >>> import numpy as np - >>> X, y = np.random.normal(size=(100, 2)), np.random.uniform(size=100) - - # Define and fit tranformation basis - >>> from nemos.basis import MSplineBasis, TransformerBasis - >>> basis = MSplineBasis(10) - >>> transformer = TransformerBasis(basis) - >>> transformer = transformer.fit(X) - """ self._basis._set_kernel(*self._unpack_inputs(X)) return self diff --git a/src/nemos/glm.py b/src/nemos/glm.py index 44cbb4d4..bbb6daf2 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -108,7 +108,7 @@ class GLM(BaseRegressor): -------- >>> import nemos as nmo - # define simple GLM model + >>> # define simple GLM model >>> model = nmo.glm.GLM() >>> print("Regularizer type: ", type(model.regularizer)) Regularizer type: @@ -116,7 +116,7 @@ class GLM(BaseRegressor): Observation model: - # define GLM model of PoissonObservations model with soft-plus NL + >>> # define GLM model of PoissonObservations model with soft-plus NL >>> observation_models = nmo.observation_models.PoissonObservations(jax.nn.softplus) >>> model = nmo.glm.GLM(observation_model=observation_models, solver_name="LBFGS") >>> print("Regularizer type: ", type(model.regularizer)) @@ -328,17 +328,17 @@ def predict(self, X: DESIGN_INPUT_TYPE) -> jnp.ndarray: Examples -------- - # example input + >>> # example input >>> import numpy as np >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10) >>> Xnew = np.random.normal(size=(20, ) + X.shape[1:]) - # define and fit a GLM + >>> # define and fit a GLM >>> import nemos as nmo >>> model = nmo.glm.GLM() >>> model = model.fit(X, y) - # predict new spike data + >>> # predict new spike data >>> predicted_spikes = model.predict(Xnew) See Also @@ -443,7 +443,7 @@ def score( Examples -------- - # example input + >>> # example input >>> import numpy as np >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10) @@ -451,7 +451,7 @@ def score( >>> model = nmo.glm.GLM() >>> model = model.fit(X, y) - # get model score + >>> # get model score >>> log_likelihood_score = model.score(X, y) >>> pseudo_r2_score = model.score(X, y, score_type='pseudo-r2-McFadden') @@ -655,16 +655,16 @@ def fit( Examples ------- - # example input + >>> # example input >>> import numpy as np >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10) - # fit a ridge regression Poisson GLM + >>> # fit a ridge regression Poisson GLM >>> import nemos as nmo >>> model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1) >>> model = model.fit(X, y) - # get model weights + >>> # get model weights >>> model_weights = model.coef_ """ @@ -736,50 +736,50 @@ def simulate( ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Simulate neural activity in response to a feed-forward input. - Parameters - ---------- - random_key : - jax.random.key for seeding the simulation. - feedforward_input : - External input matrix to the model, representing factors like convolved currents, - light intensities, etc. When not provided, the simulation is done with coupling-only. - Array of shape (n_time_bins, n_basis_input) or pytree of same. - - Returns - ------- - simulated_activity : - Simulated activity (spike counts for PoissonGLMs) for the neuron over time. - Shape: (n_time_bins, ). - firing_rates : - Simulated rates for the neuron over time. Shape, (n_time_bins, ). - - Raises - ------ - NotFittedError - If the model hasn't been fitted prior to calling this method. - ValueError - - If the instance has not been previously fitted. - - Examples - -------- - # example input - >>> import numpy as np - >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10) - >>> Xnew = np.random.normal(size=(20, ) + X.shape[1:]) - - # define and fit model - >>> import nemos as nmo - >>> model = nmo.glm.GLM() - >>> model = model.fit(X, y) - - # generate spikes and rates - >>> random_key = jax.random.key(123) - >>> spikes, rates = model.simulate(random_key, Xnew) - - See Also - -------- - [predict](./#nemos.glm.GLM.predict) : - Method to predict rates based on the model's parameters. + Parameters + ---------- + random_key : + jax.random.key for seeding the simulation. + feedforward_input : + External input matrix to the model, representing factors like convolved currents, + light intensities, etc. When not provided, the simulation is done with coupling-only. + Array of shape (n_time_bins, n_basis_input) or pytree of same. + + Returns + ------- + simulated_activity : + Simulated activity (spike counts for Poisson GLMs) for the neuron over time. + Shape: (n_time_bins, ). + firing_rates : + Simulated rates for the neuron over time. Shape, (n_time_bins, ). + + Raises + ------ + NotFittedError + If the model hasn't been fitted prior to calling this method. + ValueError + - If the instance has not been previously fitted. + + Examples + -------- + >>> # example input + >>> import numpy as np + >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10) + >>> Xnew = np.random.normal(size=(20, ) + X.shape[1:]) + + >>> # define and fit model + >>> import nemos as nmo + >>> model = nmo.glm.GLM() + >>> model = model.fit(X, y) + + >>> # generate spikes and rates + >>> random_key = jax.random.key(123) + >>> spikes, rates = model.simulate(random_key, Xnew) + + See Also + -------- + [predict](./#nemos.glm.GLM.predict) : + Method to predict rates based on the model's parameters. """ # check if the model is fit self._check_is_fit() @@ -1453,6 +1453,26 @@ def fit( - If the mask is a `FeaturePytree`, then `"feature_name"` is a predictor of neuron `j` if `feature_mask["feature_name"][j] == 1`. + Examples + -------- + >>> # Generate sample data + >>> import jax.numpy as jnp + >>> import numpy as np + >>> from nemos.glm import PopulationGLM + + >>> # Define predictors (X), weights, and neural activity (y) + >>> num_samples, num_features, num_neurons = 100, 3, 2 + >>> X = np.random.normal(size=(num_samples, num_features)) + >>> weights = np.array([[ 0.5, 0. ], [-0.5, -0.5], [ 0. , 1. ]]) + >>> y = np.random.poisson(np.exp(X.dot(weights))) + + >>> # Define a feature mask, shape (num_features, num_neurons) + >>> feature_mask = jnp.array([[1, 0], [1, 1], [0, 1]]) + + >>> # Create and fit the model + >>> model = PopulationGLM(feature_mask=feature_mask).fit(X, y) + >>> print(model.coef_.shape) + (3, 2) """ return super().fit(X, y, init_params) From 544b8703cbe570a8a1b00c3bcaef2ab8385d70a3 Mon Sep 17 00:00:00 2001 From: pranmod01 Date: Wed, 23 Oct 2024 09:33:30 -0700 Subject: [PATCH 11/18] fixed oct 23 comments --- src/nemos/basis.py | 1 - src/nemos/glm.py | 12 +++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index f98c7abb..1b0c9f12 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -204,7 +204,6 @@ def fit(self, X: FeatureMatrix, y=None): ------- self : The transformer object. - """ self._basis._set_kernel(*self._unpack_inputs(X)) return self diff --git a/src/nemos/glm.py b/src/nemos/glm.py index bbb6daf2..df03eb5e 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -331,7 +331,6 @@ def predict(self, X: DESIGN_INPUT_TYPE) -> jnp.ndarray: >>> # example input >>> import numpy as np >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10) - >>> Xnew = np.random.normal(size=(20, ) + X.shape[1:]) >>> # define and fit a GLM >>> import nemos as nmo @@ -339,6 +338,7 @@ def predict(self, X: DESIGN_INPUT_TYPE) -> jnp.ndarray: >>> model = model.fit(X, y) >>> # predict new spike data + >>> Xnew = np.random.normal(size=(20, ) + X.shape[1:]) >>> predicted_spikes = model.predict(Xnew) See Also @@ -453,7 +453,6 @@ def score( >>> # get model score >>> log_likelihood_score = model.score(X, y) - >>> pseudo_r2_score = model.score(X, y, score_type='pseudo-r2-McFadden') Notes ----- @@ -664,8 +663,9 @@ def fit( >>> model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1) >>> model = model.fit(X, y) - >>> # get model weights + >>> # get model weights and intercept >>> model_weights = model.coef_ + >>> model_intercept = model.intercept_ """ # validate the inputs & initialize solver @@ -765,7 +765,6 @@ def simulate( >>> # example input >>> import numpy as np >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10) - >>> Xnew = np.random.normal(size=(20, ) + X.shape[1:]) >>> # define and fit model >>> import nemos as nmo @@ -774,6 +773,7 @@ def simulate( >>> # generate spikes and rates >>> random_key = jax.random.key(123) + >>> Xnew = np.random.normal(size=(20, ) + X.shape[1:]) >>> spikes, rates = model.simulate(random_key, Xnew) See Also @@ -1463,7 +1463,9 @@ def fit( >>> # Define predictors (X), weights, and neural activity (y) >>> num_samples, num_features, num_neurons = 100, 3, 2 >>> X = np.random.normal(size=(num_samples, num_features)) - >>> weights = np.array([[ 0.5, 0. ], [-0.5, -0.5], [ 0. , 1. ]]) + >>> # Weights is defined by how each feature influences the output, shape (num_features, num_neurons) + >>> weights = np.array([[ 0.5, 0. ], [-0.5, -0.5], [ 0. , 1. ]]) + >>> # Output y simulates a Poisson distribution based on a linear model between features X and wegihts >>> y = np.random.poisson(np.exp(X.dot(weights))) >>> # Define a feature mask, shape (num_features, num_neurons) From 7439e49954cae5953726df5a51d282e356bf00b8 Mon Sep 17 00:00:00 2001 From: pranmod01 Date: Wed, 23 Oct 2024 09:45:03 -0700 Subject: [PATCH 12/18] added example to initialize_state --- src/nemos/glm.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/nemos/glm.py b/src/nemos/glm.py index df03eb5e..03f15638 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -948,6 +948,16 @@ def initialize_state( ------- NamedTuple The initialized solver state + + Examples + -------- + >>> import numpy as np + >>> import nemos as nmo + >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10) + >>> model = nmo.glm.GLM() + >>> params = model.initialize_params(X, y) + >>> opt_state = model.initialize_state(X, y, params) + >>> # Now ready to run optimization or update steps """ if isinstance(X, FeaturePytree): data = X.data From cfc0045700f936b48b1fab76d9dee390583762b0 Mon Sep 17 00:00:00 2001 From: pranmod01 Date: Wed, 23 Oct 2024 09:49:47 -0700 Subject: [PATCH 13/18] tox fixes --- src/nemos/glm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/glm.py b/src/nemos/glm.py index 03f15638..3e90c1b4 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -1474,7 +1474,7 @@ def fit( >>> num_samples, num_features, num_neurons = 100, 3, 2 >>> X = np.random.normal(size=(num_samples, num_features)) >>> # Weights is defined by how each feature influences the output, shape (num_features, num_neurons) - >>> weights = np.array([[ 0.5, 0. ], [-0.5, -0.5], [ 0. , 1. ]]) + >>> weights = np.array([[ 0.5, 0. ], [-0.5, -0.5], [ 0. , 1. ]]) >>> # Output y simulates a Poisson distribution based on a linear model between features X and wegihts >>> y = np.random.poisson(np.exp(X.dot(weights))) From 94d87de679c20c889358727100ef8a305148c5d8 Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Fri, 25 Oct 2024 16:12:53 -0400 Subject: [PATCH 14/18] Update src/nemos/glm.py Co-authored-by: William F. Broderick --- src/nemos/glm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/glm.py b/src/nemos/glm.py index 3e90c1b4..908ef5d9 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -338,7 +338,7 @@ def predict(self, X: DESIGN_INPUT_TYPE) -> jnp.ndarray: >>> model = model.fit(X, y) >>> # predict new spike data - >>> Xnew = np.random.normal(size=(20, ) + X.shape[1:]) + >>> Xnew = np.random.normal(size=(20, X.shape[1])) >>> predicted_spikes = model.predict(Xnew) See Also From 71cfe4ca485b81a32ad98aaaf74188e5bc37f3c0 Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Fri, 25 Oct 2024 16:13:07 -0400 Subject: [PATCH 15/18] Update src/nemos/glm.py Co-authored-by: William F. Broderick --- src/nemos/glm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/glm.py b/src/nemos/glm.py index 908ef5d9..787e5d65 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -773,7 +773,7 @@ def simulate( >>> # generate spikes and rates >>> random_key = jax.random.key(123) - >>> Xnew = np.random.normal(size=(20, ) + X.shape[1:]) + >>> Xnew = np.random.normal(size=(20, X.shape[1])) >>> spikes, rates = model.simulate(random_key, Xnew) See Also From b0fcfe4e3c3447780df9c5bead8851313bc48f45 Mon Sep 17 00:00:00 2001 From: pranmod01 Date: Sat, 26 Oct 2024 12:46:48 -0700 Subject: [PATCH 16/18] latest comment fixes from 10/25 --- src/nemos/glm.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/nemos/glm.py b/src/nemos/glm.py index 787e5d65..c6e1a392 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -108,7 +108,7 @@ class GLM(BaseRegressor): -------- >>> import nemos as nmo - >>> # define simple GLM model + >>> # define single neuron GLM model >>> model = nmo.glm.GLM() >>> print("Regularizer type: ", type(model.regularizer)) Regularizer type: @@ -330,7 +330,7 @@ def predict(self, X: DESIGN_INPUT_TYPE) -> jnp.ndarray: -------- >>> # example input >>> import numpy as np - >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10) + >>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10) >>> # define and fit a GLM >>> import nemos as nmo @@ -445,7 +445,7 @@ def score( -------- >>> # example input >>> import numpy as np - >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10) + >>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10) >>> import nemos as nmo >>> model = nmo.glm.GLM() @@ -454,6 +454,9 @@ def score( >>> # get model score >>> log_likelihood_score = model.score(X, y) + >>> # get a pseudo-R2 score + >>> pseudo_r2_score = model.score(X, y, score_type='pseudo-r2-McFadden') + Notes ----- The log-likelihood is not on a standard scale, its value is influenced by many factors, From d4af340371139c6558a8bb1576a6e2503e6ad2c7 Mon Sep 17 00:00:00 2001 From: pranmod01 Date: Sat, 26 Oct 2024 12:48:03 -0700 Subject: [PATCH 17/18] tox fix --- src/nemos/glm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nemos/glm.py b/src/nemos/glm.py index c6e1a392..31764e98 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -445,7 +445,7 @@ def score( -------- >>> # example input >>> import numpy as np - >>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10) + >>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10) >>> import nemos as nmo >>> model = nmo.glm.GLM() @@ -454,7 +454,7 @@ def score( >>> # get model score >>> log_likelihood_score = model.score(X, y) - >>> # get a pseudo-R2 score + >>> # get a pseudo-R2 score >>> pseudo_r2_score = model.score(X, y, score_type='pseudo-r2-McFadden') Notes From 12a3b161f5ab88d0384d9bcb2fcb625b94503caf Mon Sep 17 00:00:00 2001 From: pranmod01 Date: Sat, 26 Oct 2024 12:52:05 -0700 Subject: [PATCH 18/18] fixed comments from 10/25 + tox --- src/nemos/glm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/nemos/glm.py b/src/nemos/glm.py index 31764e98..74cbdee5 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -659,7 +659,7 @@ def fit( ------- >>> # example input >>> import numpy as np - >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10) + >>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10) >>> # fit a ridge regression Poisson GLM >>> import nemos as nmo @@ -767,7 +767,7 @@ def simulate( -------- >>> # example input >>> import numpy as np - >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10) + >>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10) >>> # define and fit model >>> import nemos as nmo @@ -956,7 +956,7 @@ def initialize_state( -------- >>> import numpy as np >>> import nemos as nmo - >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10) + >>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10) >>> model = nmo.glm.GLM() >>> params = model.initialize_params(X, y) >>> opt_state = model.initialize_state(X, y, params)