Skip to content

Commit

Permalink
fixed comments, added example for popglm fit, removed misc basis exam…
Browse files Browse the repository at this point in the history
…ples
  • Loading branch information
pranmod01 committed Oct 22, 2024
1 parent 55e6d64 commit 5aba7d6
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 66 deletions.
12 changes: 0 additions & 12 deletions src/nemos/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,18 +205,6 @@ def fit(self, X: FeatureMatrix, y=None):
self :
The transformer object.
Examples
--------
# Example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(100, 2)), np.random.uniform(size=100)
# Define and fit tranformation basis
>>> from nemos.basis import MSplineBasis, TransformerBasis
>>> basis = MSplineBasis(10)
>>> transformer = TransformerBasis(basis)
>>> transformer = transformer.fit(X)
"""
self._basis._set_kernel(*self._unpack_inputs(X))
return self
Expand Down
128 changes: 74 additions & 54 deletions src/nemos/glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,15 @@ class GLM(BaseRegressor):
--------
>>> import nemos as nmo
# define simple GLM model
>>> # define simple GLM model
>>> model = nmo.glm.GLM()
>>> print("Regularizer type: ", type(model.regularizer))
Regularizer type: <class 'nemos.regularizer.UnRegularized'>
>>> print("Observation model: ", type(model.observation_model))
Observation model: <class 'nemos.observation_models.PoissonObservations'>
# define GLM model of PoissonObservations model with soft-plus NL
>>> # define GLM model of PoissonObservations model with soft-plus NL
>>> observation_models = nmo.observation_models.PoissonObservations(jax.nn.softplus)
>>> model = nmo.glm.GLM(observation_model=observation_models, solver_name="LBFGS")
>>> print("Regularizer type: ", type(model.regularizer))
Expand Down Expand Up @@ -328,17 +328,17 @@ def predict(self, X: DESIGN_INPUT_TYPE) -> jnp.ndarray:
Examples
--------
# example input
>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10)
>>> Xnew = np.random.normal(size=(20, ) + X.shape[1:])
# define and fit a GLM
>>> # define and fit a GLM
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)
# predict new spike data
>>> # predict new spike data
>>> predicted_spikes = model.predict(Xnew)
See Also
Expand Down Expand Up @@ -443,15 +443,15 @@ def score(
Examples
--------
# example input
>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10)
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)
# get model score
>>> # get model score
>>> log_likelihood_score = model.score(X, y)
>>> pseudo_r2_score = model.score(X, y, score_type='pseudo-r2-McFadden')
Expand Down Expand Up @@ -655,16 +655,16 @@ def fit(
Examples
-------
# example input
>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10)
# fit a ridge regression Poisson GLM
>>> # fit a ridge regression Poisson GLM
>>> import nemos as nmo
>>> model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1)
>>> model = model.fit(X, y)
# get model weights
>>> # get model weights
>>> model_weights = model.coef_
"""
Expand Down Expand Up @@ -736,50 +736,50 @@ def simulate(
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Simulate neural activity in response to a feed-forward input.
Parameters
----------
random_key :
jax.random.key for seeding the simulation.
feedforward_input :
External input matrix to the model, representing factors like convolved currents,
light intensities, etc. When not provided, the simulation is done with coupling-only.
Array of shape (n_time_bins, n_basis_input) or pytree of same.
Returns
-------
simulated_activity :
Simulated activity (spike counts for PoissonGLMs) for the neuron over time.
Shape: (n_time_bins, ).
firing_rates :
Simulated rates for the neuron over time. Shape, (n_time_bins, ).
Raises
------
NotFittedError
If the model hasn't been fitted prior to calling this method.
ValueError
- If the instance has not been previously fitted.
Examples
--------
# example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10)
>>> Xnew = np.random.normal(size=(20, ) + X.shape[1:])
# define and fit model
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)
# generate spikes and rates
>>> random_key = jax.random.key(123)
>>> spikes, rates = model.simulate(random_key, Xnew)
See Also
--------
[predict](./#nemos.glm.GLM.predict) :
Method to predict rates based on the model's parameters.
Parameters
----------
random_key :
jax.random.key for seeding the simulation.
feedforward_input :
External input matrix to the model, representing factors like convolved currents,
light intensities, etc. When not provided, the simulation is done with coupling-only.
Array of shape (n_time_bins, n_basis_input) or pytree of same.
Returns
-------
simulated_activity :
Simulated activity (spike counts for Poisson GLMs) for the neuron over time.
Shape: (n_time_bins, ).
firing_rates :
Simulated rates for the neuron over time. Shape, (n_time_bins, ).
Raises
------
NotFittedError
If the model hasn't been fitted prior to calling this method.
ValueError
- If the instance has not been previously fitted.
Examples
--------
>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10)
>>> Xnew = np.random.normal(size=(20, ) + X.shape[1:])
>>> # define and fit model
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)
>>> # generate spikes and rates
>>> random_key = jax.random.key(123)
>>> spikes, rates = model.simulate(random_key, Xnew)
See Also
--------
[predict](./#nemos.glm.GLM.predict) :
Method to predict rates based on the model's parameters.
"""
# check if the model is fit
self._check_is_fit()
Expand Down Expand Up @@ -1453,6 +1453,26 @@ def fit(
- If the mask is a `FeaturePytree`, then `"feature_name"` is a predictor of neuron `j` if
`feature_mask["feature_name"][j] == 1`.
Examples
--------
>>> # Generate sample data
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from nemos.glm import PopulationGLM
>>> # Define predictors (X), weights, and neural activity (y)
>>> num_samples, num_features, num_neurons = 100, 3, 2
>>> X = np.random.normal(size=(num_samples, num_features))
>>> weights = np.array([[ 0.5, 0. ], [-0.5, -0.5], [ 0. , 1. ]])
>>> y = np.random.poisson(np.exp(X.dot(weights)))
>>> # Define a feature mask, shape (num_features, num_neurons)
>>> feature_mask = jnp.array([[1, 0], [1, 1], [0, 1]])
>>> # Create and fit the model
>>> model = PopulationGLM(feature_mask=feature_mask).fit(X, y)
>>> print(model.coef_.shape)
(3, 2)
"""
return super().fit(X, y, init_params)

Expand Down

0 comments on commit 5aba7d6

Please sign in to comment.