Glm class restructure #41

Merged: 261 commits, Dec 14, 2023

Changes from 1 commit

Commits (261)
5507c63
fixed naming and improved docstrings
BalzaniEdoardo Aug 21, 2023
b9cb88f
initialize parameters to none
BalzaniEdoardo Aug 21, 2023
451be3a
improved docstrings, started off tests
BalzaniEdoardo Aug 21, 2023
e02f722
test glm started
BalzaniEdoardo Aug 21, 2023
94efb2c
updated tests
BalzaniEdoardo Aug 23, 2023
d88ee96
test completed, need docstrings
BalzaniEdoardo Aug 23, 2023
f40b1e1
linted and removed dandi test
BalzaniEdoardo Aug 23, 2023
07c9e56
linted src
BalzaniEdoardo Aug 23, 2023
bcd7cf9
improved description of simulate
BalzaniEdoardo Aug 23, 2023
5fafb5d
improved description of simulate
BalzaniEdoardo Aug 23, 2023
28db00c
bugfixed config
BalzaniEdoardo Aug 23, 2023
405b17f
improved logic of fit, add check nans
BalzaniEdoardo Aug 24, 2023
ba07cff
added short docstrings to tests
BalzaniEdoardo Aug 24, 2023
3f4cab2
tested model_base
BalzaniEdoardo Aug 24, 2023
f2ac53e
started off the developers' notes?
BalzaniEdoardo Aug 25, 2023
ff2a549
base_model restructured and modified note
BalzaniEdoardo Aug 25, 2023
b9b5b85
added docstrings to model_base.py
BalzaniEdoardo Aug 25, 2023
096b306
massive refractoring, makes sure that score and simulate has a GLM-ty…
BalzaniEdoardo Aug 25, 2023
8d2622c
fixed example in Error
BalzaniEdoardo Aug 25, 2023
5d80e9d
linted all
BalzaniEdoardo Aug 25, 2023
6c3da95
switched residual deviance to public
BalzaniEdoardo Aug 25, 2023
a342808
bugfixed tests
BalzaniEdoardo Aug 26, 2023
91680e1
refractor classes names
BalzaniEdoardo Aug 26, 2023
c8dd031
bugfixed test_base
BalzaniEdoardo Aug 26, 2023
c56fd7e
added an end-to-end fit
BalzaniEdoardo Aug 28, 2023
7e21a36
improved docstrings of base_class
BalzaniEdoardo Aug 28, 2023
02be1da
improved notes
BalzaniEdoardo Aug 28, 2023
ffd640b
added scheme of model classes
BalzaniEdoardo Aug 29, 2023
9c6e135
linted glm.py, removed basic_test.py
BalzaniEdoardo Aug 29, 2023
80b0676
reviewed glm note
BalzaniEdoardo Aug 29, 2023
a0bafdc
update docstring for simulate
BalzaniEdoardo Aug 29, 2023
d349b0d
update docstring for simulate
BalzaniEdoardo Aug 29, 2023
5ff37e1
improved structure
BalzaniEdoardo Aug 29, 2023
6a5840e
imrpoved docstrings
BalzaniEdoardo Aug 29, 2023
29a8dfd
linted
BalzaniEdoardo Aug 29, 2023
32277b5
added device check on fit
BalzaniEdoardo Aug 31, 2023
9567546
Update 02-base_class.md
BalzaniEdoardo Sep 6, 2023
d6458b2
improved exception messages
BalzaniEdoardo Sep 6, 2023
31d1b85
flexible data_type
BalzaniEdoardo Sep 6, 2023
a9c2a9e
linted
BalzaniEdoardo Sep 6, 2023
f61b1f7
data_type as parameter
BalzaniEdoardo Sep 6, 2023
3034a9a
linted
BalzaniEdoardo Sep 6, 2023
aec419e
improved docstrings
BalzaniEdoardo Sep 6, 2023
0d5e65a
pydoctest linting
BalzaniEdoardo Sep 6, 2023
1c54339
pydoctest linting
BalzaniEdoardo Sep 6, 2023
c415c68
started loss and obs noise
BalzaniEdoardo Sep 7, 2023
0912e91
added tests
BalzaniEdoardo Sep 9, 2023
7e613d9
added comprehensive testing
BalzaniEdoardo Sep 9, 2023
7cfd5c2
add code of conduct
BalzaniEdoardo Sep 9, 2023
b22de56
fixed typing
BalzaniEdoardo Sep 10, 2023
4eddb67
use jax config as default dtype
BalzaniEdoardo Sep 10, 2023
79b5997
make sure the mask is configured before the runner is called
BalzaniEdoardo Sep 10, 2023
7504215
removed unused deps
BalzaniEdoardo Sep 10, 2023
f341c86
glm edited after discussion
BalzaniEdoardo Sep 11, 2023
f56bd4f
fixed example
BalzaniEdoardo Sep 11, 2023
779bf2d
removed firing and spiking
BalzaniEdoardo Sep 12, 2023
6d43fb4
start test refractoring
BalzaniEdoardo Sep 13, 2023
3fc04b7
fixed linking of noise model
BalzaniEdoardo Sep 13, 2023
db8c9c0
imporoved text
BalzaniEdoardo Sep 13, 2023
93c3cb6
refractor of test to be continued
BalzaniEdoardo Sep 13, 2023
7fc6d6e
refractored glm tests with helper function
BalzaniEdoardo Sep 14, 2023
f87e01d
removed data_type, improved docstrings, fixed tests
BalzaniEdoardo Sep 18, 2023
0c2d173
fixed deprecations
BalzaniEdoardo Sep 19, 2023
788f3d4
added tests for corner cases
BalzaniEdoardo Sep 19, 2023
ee00d13
set solver_name and solver_kwargs as propertied
BalzaniEdoardo Sep 19, 2023
263c530
set noise model link func as a property
BalzaniEdoardo Sep 19, 2023
3cfc0b0
linted
BalzaniEdoardo Sep 19, 2023
e430c35
bugfixed tests (compiled jax funcs start with jaxlib)
BalzaniEdoardo Sep 19, 2023
55ec5e3
added tests for unreg solver
BalzaniEdoardo Sep 19, 2023
064222d
added unreg tests and pointed to src for coverage
BalzaniEdoardo Sep 19, 2023
b37e9b5
added unreg tests and pointed to src for coverage
BalzaniEdoardo Sep 19, 2023
01c86d0
improved docs
BalzaniEdoardo Sep 20, 2023
753a64d
edited 02-base_class.md note
BalzaniEdoardo Sep 20, 2023
29eb4ba
edited note 03-glm.md
BalzaniEdoardo Sep 20, 2023
23e0557
edited note 03-glm.md
BalzaniEdoardo Sep 21, 2023
6e83ede
bugfixed plots
BalzaniEdoardo Sep 21, 2023
5f8b78a
first pass to the noise model dev note
BalzaniEdoardo Sep 21, 2023
b30b48f
completed note 04-noise_model.md
BalzaniEdoardo Sep 21, 2023
ad99760
started new note on solver
BalzaniEdoardo Sep 22, 2023
3211e8a
minor edits and grammar checks
BalzaniEdoardo Sep 22, 2023
8e39d5e
added a glossary
BalzaniEdoardo Sep 22, 2023
2b2962d
edited note
BalzaniEdoardo Sep 22, 2023
0aec19d
edited notes, added hyeperlinks
BalzaniEdoardo Sep 22, 2023
4e3f51a
linted
BalzaniEdoardo Sep 22, 2023
3767617
added one line docstrings for abstract methods
BalzaniEdoardo Sep 22, 2023
34e2570
linted docstrings
BalzaniEdoardo Sep 22, 2023
c9df746
linted docstrings with pydocstyle
BalzaniEdoardo Sep 22, 2023
fc962a5
mypy fixes
BalzaniEdoardo Sep 22, 2023
1cf9d07
changed notes order
BalzaniEdoardo Sep 22, 2023
b0a9618
fixed examples
BalzaniEdoardo Sep 22, 2023
48e9e81
restart the ci
BalzaniEdoardo Sep 22, 2023
00b2401
show inherited methods in docs
BalzaniEdoardo Sep 22, 2023
6dee6af
merged main
BalzaniEdoardo Sep 22, 2023
0cd45f9
added new modules to index
BalzaniEdoardo Sep 26, 2023
44e5558
Update docs/developers_notes/02-base_class.md
BalzaniEdoardo Oct 26, 2023
b8ca932
Update docs/developers_notes/02-base_class.md
BalzaniEdoardo Oct 26, 2023
65b2df4
changed CI to edit with isort and not just check
BalzaniEdoardo Oct 26, 2023
faf0e60
removed index
BalzaniEdoardo Oct 26, 2023
89112fd
replaces yaml with json. removed yaml dependency
BalzaniEdoardo Oct 26, 2023
226a752
Update mkdocs.yml
BalzaniEdoardo Oct 26, 2023
23f1cc5
removed leftover filter in mkdocstrings
BalzaniEdoardo Oct 26, 2023
099d517
testing coverage
BalzaniEdoardo Oct 30, 2023
b363705
Merge branch 'glm_class_restructure' of github.com:flatironinstitute/…
BalzaniEdoardo Oct 30, 2023
ede53f1
switched to json in tests
BalzaniEdoardo Oct 30, 2023
e43991c
requirement updated
BalzaniEdoardo Oct 30, 2023
e336d9e
bugfix version
BalzaniEdoardo Oct 30, 2023
a4eb0eb
coverage over src
BalzaniEdoardo Oct 30, 2023
c868f68
changed function naming
BalzaniEdoardo Oct 30, 2023
e434236
refractored name to Base
BalzaniEdoardo Oct 30, 2023
6c39aa3
removed hidden attributes
BalzaniEdoardo Oct 30, 2023
6020290
refer to jax
BalzaniEdoardo Oct 30, 2023
10af2fe
moved conversion to utils.py
BalzaniEdoardo Oct 30, 2023
379ae4c
moved check invalid entry to utils
BalzaniEdoardo Oct 31, 2023
3229c05
Update src/neurostatslib/base_class.py
BalzaniEdoardo Oct 31, 2023
255d969
Update src/neurostatslib/glm.py
BalzaniEdoardo Oct 31, 2023
489853f
linted
BalzaniEdoardo Oct 31, 2023
89df69c
changed mock tests and changed get_params
BalzaniEdoardo Oct 31, 2023
914ce1b
fixed typo
BalzaniEdoardo Oct 31, 2023
23310cf
create a check array function
BalzaniEdoardo Oct 31, 2023
94d5523
preprocess func private
BalzaniEdoardo Oct 31, 2023
b50017e
fixed tests
BalzaniEdoardo Oct 31, 2023
759a23d
renamed paramters in simulate preproc
BalzaniEdoardo Oct 31, 2023
9fef6c8
linted
BalzaniEdoardo Oct 31, 2023
f10e4e9
linted
BalzaniEdoardo Oct 31, 2023
3152065
added test for noise model link
BalzaniEdoardo Oct 31, 2023
9b55807
Update src/neurostatslib/noise_model.py
BalzaniEdoardo Oct 31, 2023
bbe8213
Update src/neurostatslib/noise_model.py
BalzaniEdoardo Oct 31, 2023
0f7d67b
linted
BalzaniEdoardo Oct 31, 2023
293172f
Merge branch 'glm_class_restructure' of github.com:flatironinstitute/…
BalzaniEdoardo Oct 31, 2023
6cc4822
renamed emission_probability
BalzaniEdoardo Oct 31, 2023
2aae9ef
Update src/neurostatslib/noise_model.py
BalzaniEdoardo Nov 1, 2023
19456b1
Update src/neurostatslib/glm.py
BalzaniEdoardo Nov 1, 2023
a1f1655
Merge branch 'glm_class_restructure' of github.com:flatironinstitute/…
BalzaniEdoardo Nov 1, 2023
aa454fb
Update src/neurostatslib/glm.py
BalzaniEdoardo Nov 1, 2023
f74d3ab
Update src/neurostatslib/glm.py
BalzaniEdoardo Nov 1, 2023
b29b6d7
Update src/neurostatslib/solver.py
BalzaniEdoardo Nov 1, 2023
dbdca86
Update src/neurostatslib/glm.py
BalzaniEdoardo Nov 1, 2023
881d503
Update src/neurostatslib/solver.py
BalzaniEdoardo Nov 1, 2023
bab3a9b
Merge branch 'glm_class_restructure' of github.com:flatironinstitute/…
BalzaniEdoardo Nov 1, 2023
c340dd2
linted
BalzaniEdoardo Nov 1, 2023
448d370
removed mask from proximal operator
BalzaniEdoardo Nov 1, 2023
61a2511
fixed solver mask
BalzaniEdoardo Nov 1, 2023
0183d9e
removed the renaming
BalzaniEdoardo Nov 1, 2023
ab3cd99
removed the any reference to noise
BalzaniEdoardo Nov 1, 2023
2991896
allow a loss function with 3 arguments with any variable names
BalzaniEdoardo Nov 2, 2023
2fdd66f
linted
BalzaniEdoardo Nov 2, 2023
5d43790
grammar fixed
BalzaniEdoardo Nov 2, 2023
7c12183
changed description of check func
BalzaniEdoardo Nov 2, 2023
27582cc
generalized get_runner to accept any inputs
BalzaniEdoardo Nov 3, 2023
bdefe97
simplified get_runner
BalzaniEdoardo Nov 3, 2023
86c4df8
simplified get_runner
BalzaniEdoardo Nov 3, 2023
8af2a21
solver set
BalzaniEdoardo Nov 3, 2023
b27cd31
modified glm moving the checks to the setter
BalzaniEdoardo Nov 3, 2023
3baaefb
added logistic as example
BalzaniEdoardo Nov 3, 2023
d3664c3
improved docstring for score
BalzaniEdoardo Nov 3, 2023
10255c6
bugfixed deviance
BalzaniEdoardo Nov 3, 2023
ecb069e
renamed deviance
BalzaniEdoardo Nov 3, 2023
662dfbc
moved raise in the if close
BalzaniEdoardo Nov 4, 2023
3c0fb3f
commented scanf
BalzaniEdoardo Nov 4, 2023
dbd3bbd
improved scanf algorithm
BalzaniEdoardo Nov 6, 2023
5151f7a
improved comments
BalzaniEdoardo Nov 6, 2023
b7aaa25
add test loss is callable
BalzaniEdoardo Nov 6, 2023
355199d
linted
BalzaniEdoardo Nov 6, 2023
dc02446
completed first round code revisions
BalzaniEdoardo Nov 7, 2023
c838e49
refractor parameter names
BalzaniEdoardo Nov 7, 2023
9d7ccd2
added a warning on the demo
BalzaniEdoardo Nov 7, 2023
7570908
checked conversion to nemos
BalzaniEdoardo Nov 7, 2023
2b87edc
no change
BalzaniEdoardo Nov 7, 2023
13135ce
removed refs to neurostatslib
BalzaniEdoardo Nov 7, 2023
14a17b8
removed refs to neurostatslib
BalzaniEdoardo Nov 7, 2023
b315f46
removed refs to neurostatslib
BalzaniEdoardo Nov 7, 2023
1f408fa
removed gen lin models and linted
BalzaniEdoardo Nov 7, 2023
77065a0
linted
BalzaniEdoardo Nov 7, 2023
b902219
removed gen-lin-mod
BalzaniEdoardo Nov 7, 2023
3aee269
added dev deps
BalzaniEdoardo Nov 7, 2023
cb381ff
commit pyproject.toml
BalzaniEdoardo Nov 7, 2023
db62140
linted
BalzaniEdoardo Nov 7, 2023
724522f
fixed links
BalzaniEdoardo Nov 7, 2023
15f062c
fixed links
BalzaniEdoardo Nov 7, 2023
1ef0eda
linted
BalzaniEdoardo Nov 7, 2023
a2a5cb7
fixed spacing
BalzaniEdoardo Nov 8, 2023
4ea0276
fixed the shapes
BalzaniEdoardo Nov 8, 2023
2a389cf
removed unused funcs
BalzaniEdoardo Nov 21, 2023
252b78b
fixed test
BalzaniEdoardo Nov 21, 2023
4592cfd
removed convert utils
BalzaniEdoardo Nov 21, 2023
38d6d51
fixed tests exceptions
BalzaniEdoardo Nov 21, 2023
a2981bc
fixed tests exceptions
BalzaniEdoardo Nov 21, 2023
59393f9
fixed regex tests
BalzaniEdoardo Nov 21, 2023
d40a272
fixed regex tests glm
BalzaniEdoardo Nov 21, 2023
3b43e04
fixed hyperlink
BalzaniEdoardo Nov 22, 2023
84e8925
improved _score docstrings
BalzaniEdoardo Nov 22, 2023
9debb46
added 2 types of pr2
BalzaniEdoardo Nov 24, 2023
2bdd293
float_eps removed
BalzaniEdoardo Nov 27, 2023
06dc039
moved multi-array device put to utils.py
BalzaniEdoardo Nov 27, 2023
ef170e3
improved refs
BalzaniEdoardo Nov 27, 2023
3a0e990
refractored names
BalzaniEdoardo Nov 27, 2023
412727e
fixed tests
BalzaniEdoardo Nov 27, 2023
5ec8430
improved refs to solver
BalzaniEdoardo Nov 27, 2023
5665711
refractor note solver part1
BalzaniEdoardo Nov 27, 2023
266c077
refractor module name
BalzaniEdoardo Nov 28, 2023
7362b0c
improved docstring proximal operator
BalzaniEdoardo Nov 28, 2023
1b2e055
removed ## on module docstrings
BalzaniEdoardo Nov 29, 2023
62be95f
removed note on prox op, updated baseclass
BalzaniEdoardo Nov 29, 2023
8c61fd1
removed get_runner and use super().instantiate_solver instead
BalzaniEdoardo Nov 29, 2023
46c2a58
edited note
BalzaniEdoardo Nov 29, 2023
3d1c113
switched to fixture
BalzaniEdoardo Nov 30, 2023
eb0a677
started refractoring tests
BalzaniEdoardo Nov 30, 2023
c9d6f0a
refractored test_fit_weights_dimensionality
BalzaniEdoardo Nov 30, 2023
42e8a2e
refractored test_fit_intercepts_dimensionality
BalzaniEdoardo Nov 30, 2023
c3e7b29
refractored test_fit_init_params_type
BalzaniEdoardo Nov 30, 2023
332a7eb
refractored test_fit_n_neuron_match_weights
BalzaniEdoardo Nov 30, 2023
90bbc87
refractored test_fit_n_neuron_match_baseline_rate
BalzaniEdoardo Nov 30, 2023
6e49aa6
refractored test_fit_n_neuron_match_baseline_rate
BalzaniEdoardo Nov 30, 2023
2a043f0
refractored test_fit_n_neuron_match_x
BalzaniEdoardo Nov 30, 2023
2468b73
refractored test_fit_n_neuron_match_y
BalzaniEdoardo Nov 30, 2023
4b6565b
refractored test_fit_x_dimensionality
BalzaniEdoardo Nov 30, 2023
0513540
refractored test_fit_y_dimensionality
BalzaniEdoardo Nov 30, 2023
fb862da
refractored test_fit_n_feature_consistency_weights
BalzaniEdoardo Nov 30, 2023
88fe8c8
refractored test_fit_n_feature_consistency_x
BalzaniEdoardo Nov 30, 2023
34b9b84
refractored test_fit_time_points_x
BalzaniEdoardo Nov 30, 2023
1fc86b2
test_fit_time_points_y
BalzaniEdoardo Nov 30, 2023
0071c07
test_score_n_neuron_match_x
BalzaniEdoardo Nov 30, 2023
71e6976
test_score_n_neuron_match_y
BalzaniEdoardo Nov 30, 2023
56e86b0
test_score_x_dimensionality
BalzaniEdoardo Nov 30, 2023
1b2215f
completed refractoring
BalzaniEdoardo Nov 30, 2023
2786607
linted everything
BalzaniEdoardo Nov 30, 2023
85de28f
fixed naming
BalzaniEdoardo Nov 30, 2023
41f3771
linted glm
BalzaniEdoardo Nov 30, 2023
29aab97
removed sklearn dep
BalzaniEdoardo Nov 30, 2023
979b3b4
removed unused method
BalzaniEdoardo Dec 11, 2023
899ca98
refractored names
BalzaniEdoardo Dec 11, 2023
d5c220e
undoes some unhelpful python black ormatting
billbrod Dec 11, 2023
2e2c071
fixed docstrings
BalzaniEdoardo Dec 11, 2023
9645f9d
fix docstring
billbrod Dec 11, 2023
1ee0de4
choen -> cohen
billbrod Dec 11, 2023
4cc2d34
added statsmodels equivalence in docstrings
BalzaniEdoardo Dec 11, 2023
e76ce1a
Merge branch 'glm_class_restructure' of github.com:flatironinstitute/…
BalzaniEdoardo Dec 11, 2023
40249ea
resolved conflicts
BalzaniEdoardo Dec 11, 2023
d016665
linted flake8
BalzaniEdoardo Dec 11, 2023
e9bd25c
generate parameters for simulaiton
BalzaniEdoardo Dec 12, 2023
4679dc2
removed json in docs
BalzaniEdoardo Dec 12, 2023
5d32df9
modified the glm demo
BalzaniEdoardo Dec 13, 2023
2930373
added testing
BalzaniEdoardo Dec 14, 2023
46092ed
fixed test by removing dep on json
BalzaniEdoardo Dec 14, 2023
2a815b4
added test of correctness for the lsq
BalzaniEdoardo Dec 14, 2023
58a5aba
added raises
BalzaniEdoardo Dec 14, 2023
0d1207c
Update src/nemos/simulation.py
BalzaniEdoardo Dec 14, 2023
c2dd948
Update src/nemos/simulation.py
BalzaniEdoardo Dec 14, 2023
3034922
Update src/nemos/simulation.py
BalzaniEdoardo Dec 14, 2023
2a5259b
Merge branch 'glm_class_restructure' of github.com:flatironinstitute/…
BalzaniEdoardo Dec 14, 2023
edited note 03-glm.md
BalzaniEdoardo committed Sep 20, 2023
commit 29eb4bad1eda45e9290a7b559f15a5263d590813
101 changes: 40 additions & 61 deletions docs/developers_notes/03-glm.md
@@ -2,95 +2,74 @@

## Introduction

The `neurostatslib.glm` module implements variations of Generalized Linear Models (GLMs) classes.

At this stage, the module consists of two primary classes:

1. **`_BaseGLM`:** An abstract class serving as the backbone for building GLMs.
2. **`PoissonGLM`:** A concrete implementation of the GLM for Poisson-distributed data.
Generalized Linear Models (GLM) provide a flexible framework for modeling a variety of data types while establishing a relationship between multiple predictors and a response variable. A GLM extends the traditional linear regression by allowing for response variables that have error distribution models other than a normal distribution, such as binomial or Poisson distributions.
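For reference, the model described throughout this note can be written compactly as follows (the notation is generic and not tied to the code's variable names):

$$
\mu_{t,n} = f\left( x_t \cdot w_n + b_n \right), \qquad y_{t,n} \sim \text{Noise}(\mu_{t,n}),
$$

where $f$ is the inverse link function, $w_n$ are the basis coefficients (`basis_coeff_`), $b_n$ is the baseline term (`baseline_link_fr_`), and the noise distribution (for example Poisson for spike counts) is supplied by the noise model.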

Our design aligns with the `scikit-learn` API. This ensures that our GLM classes integrate seamlessly with the robust `scikit-learn` pipeline and its cross-validation capabilities.
The `neurostatslib.glm` module currently offers implementations of two GLM classes:

## The class `_BaseGLM`
1. **`GLM`:** A direct implementation of a feedforward GLM.
2. **`RecurrentGLM`:** An implementation of a recurrent GLM. This class inherits from `GLM` and redefines the `simulate` method to generate spikes akin to a recurrent neural network.

Designed with `scikit-learn` compatibility in mind, `_BaseGLM` provides the common computations and functionalities needed by the diverse `GLM` subclasses.
Our design is harmonized with the `scikit-learn` API, facilitating seamless integration of our GLM classes with the well-established `scikit-learn` pipeline and its cross-validation tools.

### Inheritance

The `_BaseGLM` inherits attributes and methods from the `_BaseRegressor`, as detailed in the [`base_class` module](02-base_class.md). This grants `_BaseGLM` a toolkit for managing and verifying model inputs. Leveraging the inherited abstraction, all GLM subclasses must explicitly define the `fit`, `predict`, `score`, and `simulate` methods, ensuring alignment with the `scikit-learn` framework.

### Attributes

- **`solver`**: The optimization solver from jaxopt.
- **`solver_state`**: Represents the current state of the solver.
- **`basis_coeff_`**: Holds the solution for spike basis coefficients after the model has been fitted. Initialized to `None` at class instantiation.
- **`baseline_link_fr`**: Contains the bias terms' solutions after fitting. Initialized to `None` at class instantiation.
- **`kwargs`**: Other keyword arguments, like regularization hyperparameters.


### Private Methods
The classes provided here are modular by design, offering a standard foundation for any GLM variant.

- **`_check_is_fit`**: Ensures the instance has been fitted. This check is implemented here and not in `_BaseRegressor` because the model parameters are likely to be GLM specific.
- **`_predict`**: Forecasts firing rates based on predictors and parameters.
- **`_pseudo_r2`**: Computes the Pseudo-$R^2$ for a GLM, giving insight into the model's fit relative to a null model.
- **`_safe_predict`**: Validates the model's fit status and input consistency before calculating mean rates using the `_predict` method.
- **`_safe_score`**: Scores the predicted firing rates against target spike counts. Can compute either the GLM mean log-likelihood or the pseudo-$R^2$.
- **`_safe_fit`**: Fit the GLM to the neural activity. Verifies input conformity, then leverages the `jaxopt` optimizer on the designated loss function (provided by the concrete GLM subclass).
- **`_safe_simulate`**: Simulates spike trains using the GLM as a recurrent network. It projects neural activity into the future using the fitted parameters of the GLM. The function can simulate activity based on both historical spike activity and external feedforward inputs, such as convolved currents, light intensities, etc.
Instantiating a specific GLM simply requires providing an observation noise model (Gamma, Poisson, etc.) and a regularization strategy (Ridge, Lasso, etc.) during initialization.

![Title](GLM_scheme.jpg){ width="512" }
<figure markdown>
<figcaption>Schematic of the module interactions.</figcaption>
</figure>
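As a concrete illustration, here is a minimal instantiation sketch. It assumes the module layout visible in this PR (`neurostatslib.glm`, `neurostatslib.noise_model`, `neurostatslib.solver`) and the defaults shown in `glm.py`; note that the recurrent class appears as `GLMRecurrent` in the code.

```python
from neurostatslib.glm import GLM, GLMRecurrent          # import paths assumed
from neurostatslib.noise_model import PoissonNoiseModel
from neurostatslib.solver import RidgeSolver

# Pick the observation noise model and the regularization strategy explicitly.
model = GLM(
    noise_model=PoissonNoiseModel(),
    solver=RidgeSolver("GradientDescent"),
)

# The recurrent variant keeps the same defaults (Poisson noise, Ridge solver)
# and only overrides `simulate`.
recurrent_model = GLMRecurrent()
```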

!!! note
The introduction of `_safe_predict`, `_safe_fit`, `_safe_score` and `_safe_simulate` offers the following benefits:

1. It eliminates the need for subclasses to redo checks in their `fit`, `score` and `simulate` methods, leading to concise code.
2. The methods `predict`, `score`, `fit`, and `simulate` must be defined by subclasses due to their abstract nature in `_BaseRegressor`. This ensures subclass-specific docstrings for public methods.

While `predict` is common to any GLM, we explicitly omit its implementation in `_BaseGLM` so that the method will be documented in the `Code References` under each concrete class.
## The Concrete Class `GLM`

### Abstract Methods
Besides the methods acquired from `_BaseRegressor`, `_BaseGLM` introduces:
The `GLM` class provides a direct implementation of the GLM model and is designed with `scikit-learn` compatibility in mind.

- **`residual_deviance`**: Computes a GLM's residual deviance. The deviance, on par with the likelihood, is model specific.
### Inheritance

!!! note
The residual deviance can be formulated as a function of log-likelihood. Although a concrete `_BaseGLM` implementation is feasible, subclass-specific implementations might offer increased robustness or efficiency.
`GLM` inherits from `BaseRegressor`. This inheritance mandates the direct implementation of methods like `predict`, `fit`, `score`, and `simulate`.

## The Concrete Class `PoissonGLM`
### Attributes

The class `PoissonGLM` is a concrete implementation of the un-regularized Poisson GLM model.
- **`solver`**: Refers to the optimization solver - an object of the `neurostatslib.solver.Solver` type. It uses the `jaxopt` solver to minimize the (penalized) negative log-likelihood of the GLM.
- **`noise_model`**: Represents the GLM noise model, which is an object of the `neurostatslib.noise_model.NoiseModel` type. This model determines the log-likelihood and the emission probability mechanism for the `GLM`.
- **`basis_coeff_`**: Stores the solution for spike basis coefficients as a `jnp.ndarray` after the fitting process. It is initialized as `None` during class instantiation.
- **`baseline_link_fr_`**: Stores the bias terms' solutions as a `jnp.ndarray` after the fitting process. It is initialized as `None` during class instantiation.
- **`solver_state`**: Indicates the solver's state. For specific solver states, refer to the [`jaxopt` documentation](https://jaxopt.github.io/stable/index.html#).

### Inheritance
### Public Methods

`PoissonGLM` inherits from `_BaseGLM`, which provides methods for predicting firing rates and "safe" methods to score and simulate spike trains. Inheritance enforces the concrete implementation of `fit`, `score`, `simulate`, and `residual_deviance`.
- **`predict`**: Validates input and computes the mean rates of the `GLM` by invoking the inverse-link function of the `noise_model` attribute.
- **`score`**: Validates input and assesses the fitted GLM using either the log-likelihood or the pseudo-$R^2$. This method uses the `noise_model` to compute the log-likelihood or pseudo-$R^2$.
- **`fit`**: Validates input and fits the GLM to spike train data. It leverages the `noise_model` and `solver` to define the model's loss function and instantiate the solver.
- **`simulate`**: Simulates spike trains using the GLM as a feedforward network, invoking the `noise_model.emission_probability` method for emission probability.
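
A short, hedged sketch of how these methods fit together; the array layout below (`(n_time_bins, n_neurons, n_features)` for the design matrix, `(n_time_bins, n_neurons)` for the counts) is an assumption based on the dimensionality tests in this PR, not a documented contract.

```python
import jax.numpy as jnp
from neurostatslib.glm import GLM  # import path assumed

# Toy data with the assumed layout: 100 time bins, 2 neurons, 5 features.
X = jnp.ones((100, 2, 5))
y = jnp.ones((100, 2))

model = GLM()               # default Poisson noise model and Ridge solver
model.fit(X, y)             # estimates basis_coeff_ and baseline_link_fr_
rates = model.predict(X)    # mean rates through the noise model's inverse link
ll = model.score(X, y)      # log-likelihood (or pseudo-R^2) of the fit
```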

### Attributes
### Private Methods

- **`solver`**: The optimization solver from jaxopt.
- **`solver_state`**: Represents the current state of the solver.
- **`basis_coeff_`**: Holds the solution for spike basis coefficients after the model has been fitted. Initialized to `None` at class instantiation.
- **`baseline_link_fr`**: Contains the bias terms' solutions after fitting. Initialized to `None` at class instantiation.
- **`_predict`**: Forecasts rates based on current model parameters and the inverse-link function of the `noise_model`.
- **`_score`**: Determines the Poisson negative log-likelihood, excluding normalization constants.
- **`_check_is_fit`**: Validates whether the model has been appropriately fit by ensuring model parameters are set. If not, a `NotFittedError` is raised.


### Public Methods
## The Concrete Class `RecurrentGLM`

- **`predict`**: Calculates mean rates by invoking the `_safe_predict` method of `_BaseGLM`.
- **`score`**: Scores the Poisson GLM using either log-likelihood or pseudo-$R^2$. It invokes the parent `_safe_score` method to validate input and parameters.
- **`fit`**: Fits the Poisson GLM to align with spike train data by invoking `_safe_fit` and setting Poisson negative log-likelihood as the loss function.
- **`residual_deviance`**: Computes the residual deviance for each Poisson model observation, given predicted rates and spike counts.
- **`simulate`**: Simulates spike trains using the GLM as a recurrent network, invoking `_safe_simulate` and setting `jax.random.poisson` as the emission probability mechanism.
The `RecurrentGLM` class is an extension of the `GLM`, designed to simulate models with recurrent connections. It inherits the `predict`, `fit`, and `score` methods from `GLM`, but provides its own implementation for the `simulate` method.

### Private Methods
### Overridden Methods

- **`_score`**: Computes the Poisson negative log-likelihood up to a normalization constant. This method is used to define the optimization loss function for the model.
- **`simulate`**: This method simulates spike trains, treating the GLM as a recurrent neural network. It utilizes the `noise_model.emission_probability` method to determine the emission probability.

## Contributor Guidelines

### Implementing Model Subclasses

To write a usable (i.e. concrete) GLM class you
When crafting a functional (i.e., concrete) GLM class:

- **Must** inherit `_BaseGLM` or any of its subclasses.
- **Must** implement the `fit`, `score`, `simulate`, and `residual_deviance` methods, either directly or through inheritance.
- **Should** invoke `_safe_fit`, `_safe_score`, and `_safe_simulate` within the `fit`, `score`, and `simulate` methods, respectively.
- **Should not** override `_safe_fit`, `_safe_score`, or `_safe_simulate`.
- **May** integrate supplementary parameter and input checks if mandated by the GLM subclass.
- **Must** inherit from `BaseRegressor` or one of its derivatives.
- **Must** implement the `predict`, `fit`, `score`, and `simulate` methods, either directly or through inheritance.
- **Should** incorporate a `noise_model` attribute of type `neurostatslib.noise_model.NoiseModel` to specify the link function, emission probability, and likelihood.
- **Should** include a `solver` attribute of type `neurostatslib.solver.Solver` to establish the solver based on penalization type.
- **May** embed additional parameter and input checks if required by the specific GLM subclass.
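
To make the guidelines above concrete, here is a bare-bones skeleton of a hypothetical subclass; the base-class import path and the method signatures are assumptions, and the bodies are left as placeholders.

```python
from neurostatslib import noise_model as nsm, solver as slv
from neurostatslib.base_class import BaseRegressor  # import path assumed


class MyGLM(BaseRegressor):
    """Hypothetical GLM variant following the contributor guidelines."""

    def __init__(self, noise_model=nsm.PoissonNoiseModel(), solver=slv.RidgeSolver()):
        super().__init__()
        self.noise_model = noise_model  # link function, emission probability, likelihood
        self.solver = solver            # penalization / optimization strategy

    def fit(self, X, y):
        ...  # validate inputs, build the loss from noise_model, run the solver

    def predict(self, X):
        ...  # apply the noise model's inverse link to the linear predictor

    def score(self, X, y):
        ...  # log-likelihood or pseudo-R^2 via the noise model

    def simulate(self, random_key, feedforward_input):  # signature illustrative
        ...  # draw counts via noise_model.emission_probability
```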
Binary file added docs/developers_notes/GLM_scheme.jpg
3 changes: 3 additions & 0 deletions mkdocs.yml
@@ -7,6 +7,9 @@ theme:
primary: 'light blue' # The primary color palette for the theme
features:
- navigation.tabs # Enable navigation tabs feature for the theme
markdown_extensions:
- md_in_html
- admonition

plugins:
- search
24 changes: 7 additions & 17 deletions src/neurostatslib/glm.py
@@ -1,5 +1,5 @@
"""GLM core module."""
from typing import Any, Literal, Optional, Tuple, Type, Union
from typing import Any, Literal, Optional, Tuple, Union

import jax
import jax.numpy as jnp
@@ -29,8 +29,6 @@ class GLM(BaseRegressor):
solver : Solver
Solver to use for model optimization. Defines the optimization algorithm and related parameters.
Default is Ridge regression with gradient descent.
**kwargs : Any
Additional keyword arguments.

Attributes
----------
@@ -42,8 +40,6 @@ class GLM(BaseRegressor):
Model baseline link firing rate parameters after fitting.
basis_coeff_ : jnp.ndarray or None
Basis coefficients for the model after fitting.
scale : float
Scale parameter for the noise model. It's 1.0 for Poisson and Gaussian.
solver_state : Any
State of the solver after fitting. May include details like optimization error.

@@ -65,7 +61,6 @@ def __init__(
self,
noise_model: nsm.NoiseModel = nsm.PoissonNoiseModel(),
solver: slv.Solver = slv.RidgeSolver("GradientDescent"),
**kwargs: Any,
):
super().__init__()

@@ -84,14 +79,12 @@ def __init__(
self.noise_model = noise_model
self.solver = solver

# initialize to None fit output
self.baseline_link_fr_ = None
self.basis_coeff_ = None
# scale parameter (=1 for poisson and Gaussian, needs to be estimated for Gamma)
# the estimate of scale does not affect the ML estimate of the other parameter
self.scale = 1.0
self.solver_state = None

def _check_is_fit(self): # check scale.
def _check_is_fit(self):
"""Ensure the instance has been fitted."""
if (self.basis_coeff_ is None) or (self.baseline_link_fr_ is None):
raise NotFittedError(
@@ -430,21 +423,19 @@ class GLMRecurrent(GLM):
The numerical data type for internal calculations. If not provided, it will be inferred
from the data during fitting.

Attributes
----------
- The attributes of `GLMRecurrent` are inherited from the parent `GLM` class, and might include
coefficients, fitted status, and other model-related attributes.

See Also
--------
[GLM](../glm/#neurostatslib.glm.GLM) : Base class for the generalized linear model.

Notes
-----
The recurrent GLM assumes that neural activity can be influenced by both feedforward
- The recurrent GLM assumes that neural activity can be influenced by both feedforward
inputs and the past activity of the same and other neurons. This makes it particularly
powerful for capturing the dynamics of neural networks where neurons are interconnected.

- The attributes of `GLMRecurrent` are inherited from the parent `GLM` class, and might include
coefficients, fitted status, and other model-related attributes.

Examples
--------
>>> # Initialize the recurrent GLM with default parameters
@@ -457,7 +448,6 @@ def __init__(
self,
noise_model: nsm.NoiseModel = nsm.PoissonNoiseModel(),
solver: slv.Solver = slv.RidgeSolver(),
data_type: Optional[Union[Type[jnp.float32], Type[jnp.float64]]] = None,
):
super().__init__(noise_model=noise_model, solver=solver)

26 changes: 24 additions & 2 deletions src/neurostatslib/noise_model.py
@@ -45,7 +45,7 @@ def __init__(self, inverse_link_function: Callable, **kwargs):
super().__init__(**kwargs)
self._check_inverse_link_function(inverse_link_function)
self._inverse_link_function = inverse_link_function
self._scale = None
self._scale = 1.

@property
def inverse_link_function(self):
@@ -58,6 +58,18 @@ def inverse_link_function(self, inverse_link_function: Callable):
self._check_inverse_link_function(inverse_link_function)
self._inverse_link_function = inverse_link_function

@property
def scale(self):
"""Getter for the scale parameter of the model."""
return self._scale

@scale.setter
def scale(self, value: Union[int, float]):
"""Setter for the scale parameter of the model."""
if not isinstance(value, (int, float)):
raise ValueError("The `scale` parameter must be of numeric type.")
self._scale = value

@staticmethod
def _check_inverse_link_function(inverse_link_function):
if not callable(inverse_link_function):
@@ -73,7 +85,7 @@ def _check_inverse_link_function(inverse_link_function):
)

@abc.abstractmethod
def negative_log_likelihood(self, firing_rate, y):
def negative_log_likelihood(self, predicted_rate, y):
r"""Compute the noise model negative log-likelihood.

This computes the negative log-likelihood of the predicted rates
@@ -136,6 +148,11 @@ def residual_deviance(self, predicted_rate: jnp.ndarray, spike_counts: jnp.ndarr
"""
pass

@abc.abstractmethod
def estimate_scale(self, predicted_rate: jnp.ndarray) -> float:
"""Estimate the scale parameter for the model."""
pass

def pseudo_r2(self, predicted_rate: jnp.ndarray, y: jnp.ndarray):
r"""Pseudo-R^2 calculation for a GLM.

@@ -297,3 +314,8 @@ def residual_deviance(
spike_counts * jnp.log(ratio) - (spike_counts - predicted_rate)
)
return resid_dev

def estimate_scale(self, predicted_rate: jnp.ndarray):
"""Assign 1 to the scale parameter of the Poisson model."""
self.scale = 1.
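
A small, hedged usage sketch of the new `scale` machinery, based only on the setter and the Poisson `estimate_scale` shown in this diff (the import path is an assumption):

```python
import jax.numpy as jnp
from neurostatslib.noise_model import PoissonNoiseModel  # import path assumed

noise = PoissonNoiseModel()

noise.scale = 2.0                 # numeric values pass the setter's type check
noise.estimate_scale(jnp.ones(10))
print(noise.scale)                # 1.0 -- the Poisson model fixes the scale at 1

try:
    noise.scale = "not a number"  # non-numeric values are rejected
except ValueError as err:
    print(err)
```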