
Glm class restructure #41

Merged
merged 261 commits on Dec 14, 2023
5507c63
fixed naming and improved docstrings
BalzaniEdoardo Aug 21, 2023
b9cb88f
initialize parameters to none
BalzaniEdoardo Aug 21, 2023
451be3a
improved docstrings, started off tests
BalzaniEdoardo Aug 21, 2023
e02f722
test glm started
BalzaniEdoardo Aug 21, 2023
94efb2c
updated tests
BalzaniEdoardo Aug 23, 2023
d88ee96
test completed, need docstrings
BalzaniEdoardo Aug 23, 2023
f40b1e1
linted and removed dandi test
BalzaniEdoardo Aug 23, 2023
07c9e56
linted src
BalzaniEdoardo Aug 23, 2023
bcd7cf9
improved description of simulate
BalzaniEdoardo Aug 23, 2023
5fafb5d
improved description of simulate
BalzaniEdoardo Aug 23, 2023
28db00c
bugfixed config
BalzaniEdoardo Aug 23, 2023
405b17f
improved logic of fit, add check nans
BalzaniEdoardo Aug 24, 2023
ba07cff
added short docstrings to tests
BalzaniEdoardo Aug 24, 2023
3f4cab2
tested model_base
BalzaniEdoardo Aug 24, 2023
f2ac53e
started off the developers' notes?
BalzaniEdoardo Aug 25, 2023
ff2a549
base_model restructured and modified note
BalzaniEdoardo Aug 25, 2023
b9b5b85
added docstrings to model_base.py
BalzaniEdoardo Aug 25, 2023
096b306
massive refractoring, makes sure that score and simulate has a GLM-ty…
BalzaniEdoardo Aug 25, 2023
8d2622c
fixed example in Error
BalzaniEdoardo Aug 25, 2023
5d80e9d
linted all
BalzaniEdoardo Aug 25, 2023
6c3da95
switched residual deviance to public
BalzaniEdoardo Aug 25, 2023
a342808
bugfixed tests
BalzaniEdoardo Aug 26, 2023
91680e1
refractor classes names
BalzaniEdoardo Aug 26, 2023
c8dd031
bugfixed test_base
BalzaniEdoardo Aug 26, 2023
c56fd7e
added an end-to-end fit
BalzaniEdoardo Aug 28, 2023
7e21a36
improved docstrings of base_class
BalzaniEdoardo Aug 28, 2023
02be1da
improved notes
BalzaniEdoardo Aug 28, 2023
ffd640b
added scheme of model classes
BalzaniEdoardo Aug 29, 2023
9c6e135
linted glm.py, removed basic_test.py
BalzaniEdoardo Aug 29, 2023
80b0676
reviewed glm note
BalzaniEdoardo Aug 29, 2023
a0bafdc
update docstring for simulate
BalzaniEdoardo Aug 29, 2023
d349b0d
update docstring for simulate
BalzaniEdoardo Aug 29, 2023
5ff37e1
improved structure
BalzaniEdoardo Aug 29, 2023
6a5840e
imrpoved docstrings
BalzaniEdoardo Aug 29, 2023
29a8dfd
linted
BalzaniEdoardo Aug 29, 2023
32277b5
added device check on fit
BalzaniEdoardo Aug 31, 2023
9567546
Update 02-base_class.md
BalzaniEdoardo Sep 6, 2023
d6458b2
improved exception messages
BalzaniEdoardo Sep 6, 2023
31d1b85
flexible data_type
BalzaniEdoardo Sep 6, 2023
a9c2a9e
linted
BalzaniEdoardo Sep 6, 2023
f61b1f7
data_type as parameter
BalzaniEdoardo Sep 6, 2023
3034a9a
linted
BalzaniEdoardo Sep 6, 2023
aec419e
improved docstrings
BalzaniEdoardo Sep 6, 2023
0d5e65a
pydoctest linting
BalzaniEdoardo Sep 6, 2023
1c54339
pydoctest linting
BalzaniEdoardo Sep 6, 2023
c415c68
started loss and obs noise
BalzaniEdoardo Sep 7, 2023
0912e91
added tests
BalzaniEdoardo Sep 9, 2023
7e613d9
added comprehensive testing
BalzaniEdoardo Sep 9, 2023
7cfd5c2
add code of conduct
BalzaniEdoardo Sep 9, 2023
b22de56
fixed typing
BalzaniEdoardo Sep 10, 2023
4eddb67
use jax config as default dtype
BalzaniEdoardo Sep 10, 2023
79b5997
make sure the mask is configured before the runner is called
BalzaniEdoardo Sep 10, 2023
7504215
removed unused deps
BalzaniEdoardo Sep 10, 2023
f341c86
glm edited after discussion
BalzaniEdoardo Sep 11, 2023
f56bd4f
fixed example
BalzaniEdoardo Sep 11, 2023
779bf2d
removed firing and spiking
BalzaniEdoardo Sep 12, 2023
6d43fb4
start test refractoring
BalzaniEdoardo Sep 13, 2023
3fc04b7
fixed linking of noise model
BalzaniEdoardo Sep 13, 2023
db8c9c0
imporoved text
BalzaniEdoardo Sep 13, 2023
93c3cb6
refractor of test to be continued
BalzaniEdoardo Sep 13, 2023
7fc6d6e
refractored glm tests with helper function
BalzaniEdoardo Sep 14, 2023
f87e01d
removed data_type, improved docstrings, fixed tests
BalzaniEdoardo Sep 18, 2023
0c2d173
fixed deprecations
BalzaniEdoardo Sep 19, 2023
788f3d4
added tests for corner cases
BalzaniEdoardo Sep 19, 2023
ee00d13
set solver_name and solver_kwargs as propertied
BalzaniEdoardo Sep 19, 2023
263c530
set noise model link func as a property
BalzaniEdoardo Sep 19, 2023
3cfc0b0
linted
BalzaniEdoardo Sep 19, 2023
e430c35
bugfixed tests (compiled jax funcs start with jaxlib)
BalzaniEdoardo Sep 19, 2023
55ec5e3
added tests for unreg solver
BalzaniEdoardo Sep 19, 2023
064222d
added unreg tests and pointed to src for coverage
BalzaniEdoardo Sep 19, 2023
b37e9b5
added unreg tests and pointed to src for coverage
BalzaniEdoardo Sep 19, 2023
01c86d0
improved docs
BalzaniEdoardo Sep 20, 2023
753a64d
edited 02-base_class.md note
BalzaniEdoardo Sep 20, 2023
29eb4ba
edited note 03-glm.md
BalzaniEdoardo Sep 20, 2023
23e0557
edited note 03-glm.md
BalzaniEdoardo Sep 21, 2023
6e83ede
bugfixed plots
BalzaniEdoardo Sep 21, 2023
5f8b78a
first pass to the noise model dev note
BalzaniEdoardo Sep 21, 2023
b30b48f
completed note 04-noise_model.md
BalzaniEdoardo Sep 21, 2023
ad99760
started new note on solver
BalzaniEdoardo Sep 22, 2023
3211e8a
minor edits and grammar checks
BalzaniEdoardo Sep 22, 2023
8e39d5e
added a glossary
BalzaniEdoardo Sep 22, 2023
2b2962d
edited note
BalzaniEdoardo Sep 22, 2023
0aec19d
edited notes, added hyeperlinks
BalzaniEdoardo Sep 22, 2023
4e3f51a
linted
BalzaniEdoardo Sep 22, 2023
3767617
added one line docstrings for abstract methods
BalzaniEdoardo Sep 22, 2023
34e2570
linted docstrings
BalzaniEdoardo Sep 22, 2023
c9df746
linted docstrings with pydocstyle
BalzaniEdoardo Sep 22, 2023
fc962a5
mypy fixes
BalzaniEdoardo Sep 22, 2023
1cf9d07
changed notes order
BalzaniEdoardo Sep 22, 2023
b0a9618
fixed examples
BalzaniEdoardo Sep 22, 2023
48e9e81
restart the ci
BalzaniEdoardo Sep 22, 2023
00b2401
show inherited methods in docs
BalzaniEdoardo Sep 22, 2023
6dee6af
merged main
BalzaniEdoardo Sep 22, 2023
0cd45f9
added new modules to index
BalzaniEdoardo Sep 26, 2023
44e5558
Update docs/developers_notes/02-base_class.md
BalzaniEdoardo Oct 26, 2023
b8ca932
Update docs/developers_notes/02-base_class.md
BalzaniEdoardo Oct 26, 2023
65b2df4
changed CI to edit with isort and not just check
BalzaniEdoardo Oct 26, 2023
faf0e60
removed index
BalzaniEdoardo Oct 26, 2023
89112fd
replaces yaml with json. removed yaml dependency
BalzaniEdoardo Oct 26, 2023
226a752
Update mkdocs.yml
BalzaniEdoardo Oct 26, 2023
23f1cc5
removed leftover filter in mkdocstrings
BalzaniEdoardo Oct 26, 2023
099d517
testing coverage
BalzaniEdoardo Oct 30, 2023
b363705
Merge branch 'glm_class_restructure' of github.com:flatironinstitute/…
BalzaniEdoardo Oct 30, 2023
ede53f1
switched to json in tests
BalzaniEdoardo Oct 30, 2023
e43991c
requirement updated
BalzaniEdoardo Oct 30, 2023
e336d9e
bugfix version
BalzaniEdoardo Oct 30, 2023
a4eb0eb
coverage over src
BalzaniEdoardo Oct 30, 2023
c868f68
changed function naming
BalzaniEdoardo Oct 30, 2023
e434236
refractored name to Base
BalzaniEdoardo Oct 30, 2023
6c39aa3
removed hidden attributes
BalzaniEdoardo Oct 30, 2023
6020290
refer to jax
BalzaniEdoardo Oct 30, 2023
10af2fe
moved conversion to utils.py
BalzaniEdoardo Oct 30, 2023
379ae4c
moved check invalid entry to utils
BalzaniEdoardo Oct 31, 2023
3229c05
Update src/neurostatslib/base_class.py
BalzaniEdoardo Oct 31, 2023
255d969
Update src/neurostatslib/glm.py
BalzaniEdoardo Oct 31, 2023
489853f
linted
BalzaniEdoardo Oct 31, 2023
89df69c
changed mock tests and changed get_params
BalzaniEdoardo Oct 31, 2023
914ce1b
fixed typo
BalzaniEdoardo Oct 31, 2023
23310cf
create a check array function
BalzaniEdoardo Oct 31, 2023
94d5523
preprocess func private
BalzaniEdoardo Oct 31, 2023
b50017e
fixed tests
BalzaniEdoardo Oct 31, 2023
759a23d
renamed paramters in simulate preproc
BalzaniEdoardo Oct 31, 2023
9fef6c8
linted
BalzaniEdoardo Oct 31, 2023
f10e4e9
linted
BalzaniEdoardo Oct 31, 2023
3152065
added test for noise model link
BalzaniEdoardo Oct 31, 2023
9b55807
Update src/neurostatslib/noise_model.py
BalzaniEdoardo Oct 31, 2023
bbe8213
Update src/neurostatslib/noise_model.py
BalzaniEdoardo Oct 31, 2023
0f7d67b
linted
BalzaniEdoardo Oct 31, 2023
293172f
Merge branch 'glm_class_restructure' of github.com:flatironinstitute/…
BalzaniEdoardo Oct 31, 2023
6cc4822
renamed emission_probability
BalzaniEdoardo Oct 31, 2023
2aae9ef
Update src/neurostatslib/noise_model.py
BalzaniEdoardo Nov 1, 2023
19456b1
Update src/neurostatslib/glm.py
BalzaniEdoardo Nov 1, 2023
a1f1655
Merge branch 'glm_class_restructure' of github.com:flatironinstitute/…
BalzaniEdoardo Nov 1, 2023
aa454fb
Update src/neurostatslib/glm.py
BalzaniEdoardo Nov 1, 2023
f74d3ab
Update src/neurostatslib/glm.py
BalzaniEdoardo Nov 1, 2023
b29b6d7
Update src/neurostatslib/solver.py
BalzaniEdoardo Nov 1, 2023
dbdca86
Update src/neurostatslib/glm.py
BalzaniEdoardo Nov 1, 2023
881d503
Update src/neurostatslib/solver.py
BalzaniEdoardo Nov 1, 2023
bab3a9b
Merge branch 'glm_class_restructure' of github.com:flatironinstitute/…
BalzaniEdoardo Nov 1, 2023
c340dd2
linted
BalzaniEdoardo Nov 1, 2023
448d370
removed mask from proximal operator
BalzaniEdoardo Nov 1, 2023
61a2511
fixed solver mask
BalzaniEdoardo Nov 1, 2023
0183d9e
removed the renaming
BalzaniEdoardo Nov 1, 2023
ab3cd99
removed the any reference to noise
BalzaniEdoardo Nov 1, 2023
2991896
allow a loss function with 3 arguments with any variable names
BalzaniEdoardo Nov 2, 2023
2fdd66f
linted
BalzaniEdoardo Nov 2, 2023
5d43790
grammar fixed
BalzaniEdoardo Nov 2, 2023
7c12183
changed description of check func
BalzaniEdoardo Nov 2, 2023
27582cc
generalized get_runner to accept any inputs
BalzaniEdoardo Nov 3, 2023
bdefe97
simplified get_runner
BalzaniEdoardo Nov 3, 2023
86c4df8
simplified get_runner
BalzaniEdoardo Nov 3, 2023
8af2a21
solver set
BalzaniEdoardo Nov 3, 2023
b27cd31
modified glm moving the checks to the setter
BalzaniEdoardo Nov 3, 2023
3baaefb
added logistic as example
BalzaniEdoardo Nov 3, 2023
d3664c3
improved docstring for score
BalzaniEdoardo Nov 3, 2023
10255c6
bugfixed deviance
BalzaniEdoardo Nov 3, 2023
ecb069e
renamed deviance
BalzaniEdoardo Nov 3, 2023
662dfbc
moved raise in the if close
BalzaniEdoardo Nov 4, 2023
3c0fb3f
commented scanf
BalzaniEdoardo Nov 4, 2023
dbd3bbd
improved scanf algorithm
BalzaniEdoardo Nov 6, 2023
5151f7a
improved comments
BalzaniEdoardo Nov 6, 2023
b7aaa25
add test loss is callable
BalzaniEdoardo Nov 6, 2023
355199d
linted
BalzaniEdoardo Nov 6, 2023
dc02446
completed first round code revisions
BalzaniEdoardo Nov 7, 2023
c838e49
refractor parameter names
BalzaniEdoardo Nov 7, 2023
9d7ccd2
added a warning on the demo
BalzaniEdoardo Nov 7, 2023
7570908
checked conversion to nemos
BalzaniEdoardo Nov 7, 2023
2b87edc
no change
BalzaniEdoardo Nov 7, 2023
13135ce
removed refs to neurostatslib
BalzaniEdoardo Nov 7, 2023
14a17b8
removed refs to neurostatslib
BalzaniEdoardo Nov 7, 2023
b315f46
removed refs to neurostatslib
BalzaniEdoardo Nov 7, 2023
1f408fa
removed gen lin models and linted
BalzaniEdoardo Nov 7, 2023
77065a0
linted
BalzaniEdoardo Nov 7, 2023
b902219
removed gen-lin-mod
BalzaniEdoardo Nov 7, 2023
3aee269
added dev deps
BalzaniEdoardo Nov 7, 2023
cb381ff
commit pyproject.toml
BalzaniEdoardo Nov 7, 2023
db62140
linted
BalzaniEdoardo Nov 7, 2023
724522f
fixed links
BalzaniEdoardo Nov 7, 2023
15f062c
fixed links
BalzaniEdoardo Nov 7, 2023
1ef0eda
linted
BalzaniEdoardo Nov 7, 2023
a2a5cb7
fixed spacing
BalzaniEdoardo Nov 8, 2023
4ea0276
fixed the shapes
BalzaniEdoardo Nov 8, 2023
2a389cf
removed unused funcs
BalzaniEdoardo Nov 21, 2023
252b78b
fixed test
BalzaniEdoardo Nov 21, 2023
4592cfd
removed convert utils
BalzaniEdoardo Nov 21, 2023
38d6d51
fixed tests exceptions
BalzaniEdoardo Nov 21, 2023
a2981bc
fixed tests exceptions
BalzaniEdoardo Nov 21, 2023
59393f9
fixed regex tests
BalzaniEdoardo Nov 21, 2023
d40a272
fixed regex tests glm
BalzaniEdoardo Nov 21, 2023
3b43e04
fixed hyperlink
BalzaniEdoardo Nov 22, 2023
84e8925
improved _score docstrings
BalzaniEdoardo Nov 22, 2023
9debb46
added 2 types of pr2
BalzaniEdoardo Nov 24, 2023
2bdd293
float_eps removed
BalzaniEdoardo Nov 27, 2023
06dc039
moved multi-array device put to utils.py
BalzaniEdoardo Nov 27, 2023
ef170e3
improved refs
BalzaniEdoardo Nov 27, 2023
3a0e990
refractored names
BalzaniEdoardo Nov 27, 2023
412727e
fixed tests
BalzaniEdoardo Nov 27, 2023
5ec8430
improved refs to solver
BalzaniEdoardo Nov 27, 2023
5665711
refractor note solver part1
BalzaniEdoardo Nov 27, 2023
266c077
refractor module name
BalzaniEdoardo Nov 28, 2023
7362b0c
improved docstring proximal operator
BalzaniEdoardo Nov 28, 2023
1b2e055
removed ## on module docstrings
BalzaniEdoardo Nov 29, 2023
62be95f
removed note on prox op, updated baseclass
BalzaniEdoardo Nov 29, 2023
8c61fd1
removed get_runner and use super().instantiate_solver instead
BalzaniEdoardo Nov 29, 2023
46c2a58
edited note
BalzaniEdoardo Nov 29, 2023
3d1c113
switched to fixture
BalzaniEdoardo Nov 30, 2023
eb0a677
started refractoring tests
BalzaniEdoardo Nov 30, 2023
c9d6f0a
refractored test_fit_weights_dimensionality
BalzaniEdoardo Nov 30, 2023
42e8a2e
refractored test_fit_intercepts_dimensionality
BalzaniEdoardo Nov 30, 2023
c3e7b29
refractored test_fit_init_params_type
BalzaniEdoardo Nov 30, 2023
332a7eb
refractored test_fit_n_neuron_match_weights
BalzaniEdoardo Nov 30, 2023
90bbc87
refractored test_fit_n_neuron_match_baseline_rate
BalzaniEdoardo Nov 30, 2023
6e49aa6
refractored test_fit_n_neuron_match_baseline_rate
BalzaniEdoardo Nov 30, 2023
2a043f0
refractored test_fit_n_neuron_match_x
BalzaniEdoardo Nov 30, 2023
2468b73
refractored test_fit_n_neuron_match_y
BalzaniEdoardo Nov 30, 2023
4b6565b
refractored test_fit_x_dimensionality
BalzaniEdoardo Nov 30, 2023
0513540
refractored test_fit_y_dimensionality
BalzaniEdoardo Nov 30, 2023
fb862da
refractored test_fit_n_feature_consistency_weights
BalzaniEdoardo Nov 30, 2023
88fe8c8
refractored test_fit_n_feature_consistency_x
BalzaniEdoardo Nov 30, 2023
34b9b84
refractored test_fit_time_points_x
BalzaniEdoardo Nov 30, 2023
1fc86b2
test_fit_time_points_y
BalzaniEdoardo Nov 30, 2023
0071c07
test_score_n_neuron_match_x
BalzaniEdoardo Nov 30, 2023
71e6976
test_score_n_neuron_match_y
BalzaniEdoardo Nov 30, 2023
56e86b0
test_score_x_dimensionality
BalzaniEdoardo Nov 30, 2023
1b2215f
completed refractoring
BalzaniEdoardo Nov 30, 2023
2786607
linted everything
BalzaniEdoardo Nov 30, 2023
85de28f
fixed naming
BalzaniEdoardo Nov 30, 2023
41f3771
linted glm
BalzaniEdoardo Nov 30, 2023
29aab97
removed sklearn dep
BalzaniEdoardo Nov 30, 2023
979b3b4
removed unused method
BalzaniEdoardo Dec 11, 2023
899ca98
refractored names
BalzaniEdoardo Dec 11, 2023
d5c220e
undoes some unhelpful python black ormatting
billbrod Dec 11, 2023
2e2c071
fixed docstrings
BalzaniEdoardo Dec 11, 2023
9645f9d
fix docstring
billbrod Dec 11, 2023
1ee0de4
choen -> cohen
billbrod Dec 11, 2023
4cc2d34
added statsmodels equivalence in docstrings
BalzaniEdoardo Dec 11, 2023
e76ce1a
Merge branch 'glm_class_restructure' of github.com:flatironinstitute/…
BalzaniEdoardo Dec 11, 2023
40249ea
resolved conflicts
BalzaniEdoardo Dec 11, 2023
d016665
linted flake8
BalzaniEdoardo Dec 11, 2023
e9bd25c
generate parameters for simulaiton
BalzaniEdoardo Dec 12, 2023
4679dc2
removed json in docs
BalzaniEdoardo Dec 12, 2023
5d32df9
modified the glm demo
BalzaniEdoardo Dec 13, 2023
2930373
added testing
BalzaniEdoardo Dec 14, 2023
46092ed
fixed test by removing dep on json
BalzaniEdoardo Dec 14, 2023
2a815b4
added test of correctness for the lsq
BalzaniEdoardo Dec 14, 2023
58a5aba
added raises
BalzaniEdoardo Dec 14, 2023
0d1207c
Update src/nemos/simulation.py
BalzaniEdoardo Dec 14, 2023
c2dd948
Update src/nemos/simulation.py
BalzaniEdoardo Dec 14, 2023
3034922
Update src/nemos/simulation.py
BalzaniEdoardo Dec 14, 2023
2a5259b
Merge branch 'glm_class_restructure' of github.com:flatironinstitute/…
BalzaniEdoardo Dec 14, 2023
2 changes: 1 addition & 1 deletion docs/developers_notes/03-observation_models.md
@@ -24,7 +24,7 @@ For subclasses derived from `Observations` to function correctly, they must impl

### Public Methods

- - **pseudo_r2**: Method for computing the pseudo-$R^2$ of the model based on the residual deviance. There is no consensus definition for the pseudo-$R^2$, what we used here is the definition by Choen at al. 2003[^2].
+ - **pseudo_r2**: Method for computing the pseudo-$R^2$ of the model based on the residual deviance. There is no consensus definition for the pseudo-$R^2$; the one used here is the definition by Cohen et al. 2003[^2].
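For orientation, a deviance-based pseudo-$R^2$ can be sketched standalone for a Poisson observation model (the helper names below are illustrative, not the PR's API, and whether this matches the Cohen et al. formula exactly is not shown in this diff):

```python
import numpy as np

def poisson_deviance(y, rate):
    # Poisson residual deviance; the y * log(y / rate) term is taken as 0 when y == 0
    term = np.where(y > 0, y * np.log(np.where(y > 0, y, 1.0) / rate), 0.0)
    return 2.0 * np.sum(term - (y - rate))

def pseudo_r2(y, rate):
    # deviance-based pseudo-R^2: 1 - D(model) / D(null),
    # where the null model predicts the mean rate everywhere
    null_rate = np.full(y.shape, y.mean(), dtype=float)
    return 1.0 - poisson_deviance(y, rate) / poisson_deviance(y, null_rate)
```

Under this definition a perfect fit scores 1 and the constant mean-rate model scores 0.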


### Auxiliary Methods
10 changes: 5 additions & 5 deletions docs/developers_notes/04-regularizer.md
@@ -40,7 +40,7 @@ Additionally, the class provides auxiliary methods for checking that the solver

### Auxiliary Methods

- - **`_check_solver`**: This method ensures that the provided solver name is in the list of allowed optimizers for the specific `Regularizer` object. This is crucial for maintaining consistency and correctness in the solver's operation.
+ - **`_check_solver`**: This method ensures that the provided solver name is in the list of allowed solvers for the specific `Regularizer` object. This is crucial for maintaining consistency and correctness in the solver's operation.

- **`_check_solver_kwargs`**: This method checks if the provided keyword arguments are valid for the specified solver. This helps in catching and preventing potential errors in solver configuration.

@@ -50,7 +50,7 @@ The `UnRegularized` class extends the base `Regularizer` class and is designed s

### Attributes

- - **`allowed_solvers`**: A list of string identifiers for the optimization algorithms that can be used with this solver class. The optimization methods listed here are specifically suitable for unregularized optimization problems.
+ - **`allowed_solvers`**: A list of string identifiers for the optimization solvers that can be used with this regularizer class. The optimization methods listed here are specifically suitable for unregularized optimization problems.

### Methods

@@ -72,7 +72,7 @@ The `Ridge` class extends the `Regularizer` class to handle optimization problem

### Attributes

- - **`allowed_solvers`**: A list containing string identifiers of optimization algorithms compatible with Ridge regularization.
+ - **`allowed_solvers`**: A list containing string identifiers of optimization solvers compatible with Ridge regularization.

- **`regularizer_strength`**: A floating-point value determining the strength of the Ridge regularization. Higher values correspond to stronger regularization which tends to drive the model parameters towards zero.

@@ -97,7 +97,7 @@ optim_results = runner(init_params, exog_vars, endog_vars)
`ProxGradientRegularizer` class extends the `Regularizer` class to utilize the Proximal Gradient method for optimization. It leverages the `jaxopt` library's Proximal Gradient optimizer, introducing the functionality of a proximal operator.

### Attributes:
- - **`allowed_solvers`**: A list containing string identifiers of optimization algorithms compatible with this solver, specifically the "ProximalGradient".
+ - **`allowed_solvers`**: A list containing string identifiers of optimization solvers compatible with this solver, specifically the "ProximalGradient".

### Methods:
- **`__init__`**: The constructor method for the `ProxGradientRegularizer` class. It accepts the name of the solver algorithm (`solver_name`), an optional dictionary of additional keyword arguments (`solver_kwargs`) for the solver, the regularization strength (`regularizer_strength`), and an optional mask array (`mask`).
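For a Lasso penalty, the proximal operator introduced here is elementwise soft-thresholding; a minimal sketch (function name and signature are illustrative, not the `jaxopt` API):

```python
import numpy as np

def prox_lasso(params, l1_strength, scaling=1.0):
    # proximal operator of scaling * l1_strength * ||params||_1:
    # shrink every coefficient toward zero and clip at zero (soft-thresholding)
    thresh = l1_strength * scaling
    return np.sign(params) * np.maximum(np.abs(params) - thresh, 0.0)
```

Coefficients smaller in magnitude than the threshold are zeroed out, which is what makes Lasso-type penalties produce sparse solutions.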
@@ -149,7 +149,7 @@ When developing a functional (i.e., concrete) `Regularizer` class:
- **Must** inherit from `Regularizer` or one of its derivatives.
- **Must** implement the `instantiate_solver` method to tailor the solver instantiation based on the provided loss function.
- For any Proximal Gradient method, **must** include a `get_prox_operator` method to define the proximal operator.
- - **Must** possess an `allowed_solvers` attribute to list the optimizer names that are permissible to be used with this solver.
+ - **Must** possess an `allowed_solvers` attribute to list the solver names that are permissible to be used with this regularizer.
- **May** embed additional attributes and methods such as `mask` and `_check_mask` if required by the specific Solver subclass for handling special optimization scenarios.
- **May** include a `regularizer_strength` attribute to control the strength of the regularization in scenarios where regularization is applicable.
- **May** rely on a custom solver implementation for specific optimization problems, but the implementation **must** adhere to the `jaxopt` API.
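Read together, the rules above amount to something like the following minimal concrete class (a sketch under assumed names — the real base class and the `jaxopt` wiring live in the PR's `regularizer` module):

```python
import numpy as np

class MyRidge:
    """Minimal sketch of the Regularizer contract (illustrative, not the PR's code)."""

    # solver names permitted for this regularizer
    allowed_solvers = ["GradientDescent", "BFGS"]

    def __init__(self, solver_name="GradientDescent", regularizer_strength=1.0):
        self._check_solver(solver_name)
        self.solver_name = solver_name
        self.regularizer_strength = regularizer_strength

    def _check_solver(self, solver_name):
        # reject solver names that are not suitable for this penalty
        if solver_name not in self.allowed_solvers:
            raise ValueError(
                f"Solver '{solver_name}' not allowed for {type(self).__name__}; "
                f"choose one of {self.allowed_solvers}."
            )

    def penalized_loss(self, loss):
        # wrap an unpenalized loss with the L2 penalty; a real instantiate_solver
        # would hand this penalized loss to a jaxopt solver and return a runner
        def _penalized(params, X, y):
            return loss(params, X, y) + self.regularizer_strength * np.sum(params ** 2)
        return _penalized
```

The `allowed_solvers` check runs at construction time, so an incompatible solver/regularizer pairing fails loudly before any optimization starts.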
Expand Down
1 change: 0 additions & 1 deletion docs/examples/coupled_neurons_params.json

This file was deleted.

150 changes: 108 additions & 42 deletions docs/examples/plot_glm_demo.py
@@ -24,21 +24,21 @@
data.

"""
import json

import jax
import matplotlib.pyplot as plt
import numpy as np
import sklearn.model_selection as sklearn_model_selection
from matplotlib.patches import Rectangle
from sklearn import model_selection

import nemos as nmo
from nemos import simulation

# Enable float64 precision (optional)
# enable float64 precision (optional)
jax.config.update("jax_enable_x64", True)

np.random.seed(111)
# Random design tensor. Shape (n_time_points, n_neurons, n_features).
# random design tensor. Shape (n_time_points, n_neurons, n_features).
X = 0.5*np.random.normal(size=(100, 1, 5))

# log-rates & weights, shape (n_neurons, ) and (n_neurons, n_features) respectively.
@@ -78,12 +78,12 @@
# ### Model Configuration
# One could visualize the model hyperparameters by calling `get_params` method.

# Get the glm model parameters only
# get the glm model parameters only
print("\nGLM model parameters:")
for key, value in model.get_params(deep=False).items():
print(f"\t- {key}: {value}")

# Get the glm model parameters, including the all the
# get the glm model parameters, including all the
# attributes
print("\nNested parameters:")
for key, value in model.get_params(deep=True).items():
@@ -127,11 +127,11 @@

# %%
# !!! warning
# Each `Regularizer` has an associated attribute `Regularizer.allowed_optimizers`
# Each `Regularizer` has an associated attribute `Regularizer.allowed_solvers`
# which lists the optimizers that are suited for each optimization problem.
# For example, a `Ridge` is differentiable and can be fit with `GradientDescent`
# , `BFGS`, etc., while a `Lasso` should use the `ProximalGradient` method instead.
# If the provided `solver_name` is not listed in the `allowed_optimizers` this will raise an
# If the provided `solver_name` is not listed in the `allowed_solvers` this will raise an
# exception.

# %%
@@ -141,7 +141,7 @@
# Additionally one may provide an initial parameter guess.
# The same exact syntax works for any configuration.

# Fit a ridge regression Poisson GLM
# fit a ridge regression Poisson GLM
model = nmo.glm.GLM()
model.set_params(regularizer__regularizer_strength=0.1)
model.fit(X, spikes)
@@ -162,7 +162,9 @@
# **Ridge**

parameter_grid = {"regularizer__regularizer_strength": np.logspace(-1.5, 1.5, 6)}
cls = sklearn_model_selection.GridSearchCV(model, parameter_grid, cv=5)
# in practice, you should use more folds than 2, but for the purposes of this
# demo, 2 is sufficient.
cls = model_selection.GridSearchCV(model, parameter_grid, cv=2)
cls.fit(X, spikes)

print("Ridge results ")
@@ -176,7 +178,7 @@
# **Lasso**

model.set_params(regularizer=nmo.regularizer.Lasso())
cls = sklearn_model_selection.GridSearchCV(model, parameter_grid, cv=5)
cls = model_selection.GridSearchCV(model, parameter_grid, cv=2)
cls.fit(X, spikes)

print("Lasso results ")
@@ -194,7 +196,7 @@

regularizer = nmo.regularizer.GroupLasso("ProximalGradient", mask=mask)
model.set_params(regularizer=regularizer)
cls = sklearn_model_selection.GridSearchCV(model, parameter_grid, cv=5)
cls = model_selection.GridSearchCV(model, parameter_grid, cv=2)
cls.fit(X, spikes)

print("\nGroup Lasso results")
@@ -223,48 +225,112 @@
# %%
# ## Recurrently Coupled GLM
# Defining a recurrent model follows the same syntax. In this example
# we will simulate two coupled neurons. and we will inject a transient
# we will simulate two coupled neurons, and we will inject a transient
# input driving the rate of one of the neurons.
#
# For brevity, we will import the model parameters instead of generating
# them on the fly.

# load parameters
with open("coupled_neurons_params.json", "r") as fh:
config_dict = json.load(fh)

# basis weights & intercept for the GLM (both coupling and feedforward)
# (the last coefficient is the weight of the feedforward input)
basis_coeff = np.asarray(config_dict["coef_"])[:, :-1]
# Neural population parameters
n_neurons = 2
coupling_filter_duration = 100

# %%
# We can now define the coupling filters that we will use to simulate
# the pairwise interactions between the neurons. We will model the
# filters as a difference of two Gamma probability density functions.
# The negative component will capture inhibitory effects such as the
# refractory period of a neuron, while the positive component will
# describe excitation.

np.random.seed(101)

# Gamma parameters for the inhibitory component of the filter
inhib_a = 1
inhib_b = 1

# Gamma parameters for the excitatory component of the filter
excit_a = np.random.uniform(1.1, 5, size=(n_neurons, n_neurons))
excit_b = np.random.uniform(1.1, 5, size=(n_neurons, n_neurons))

# define the 2x2 bank of coupling filters with difference_of_gammas
coupling_filter_bank = np.zeros((coupling_filter_duration, n_neurons, n_neurons))
for unit_i in range(n_neurons):
for unit_j in range(n_neurons):
coupling_filter_bank[:, unit_i, unit_j] = nmo.simulation.difference_of_gammas(
coupling_filter_duration,
inhib_a=inhib_a,
excit_a=excit_a[unit_i, unit_j],
inhib_b=inhib_b,
excit_b=excit_b[unit_i, unit_j],
)

# shrink the filters for simulation stability
coupling_filter_bank *= 0.8

# Mask the weights so that only the first neuron receives the imput
basis_coeff[:, 40:] = np.abs(basis_coeff[:, 40:]) * np.array([[1.], [0.]])
# %%
# If we represent our filters in terms of basis functions, we can simulate our network by
# directly calling the `simulate` method of the `nmo.glm.GLMRecurrent` class.

intercept = np.asarray(config_dict["intercept_"])
# define a basis function
n_basis_funcs = 20
basis = nmo.basis.RaisedCosineBasisLog(n_basis_funcs)

# basis function, inputs and initial spikes
coupling_basis = jax.numpy.asarray(config_dict["coupling_basis"])
feedforward_input = jax.numpy.asarray(config_dict["feedforward_input"])
init_spikes = jax.numpy.asarray(config_dict["init_spikes"])
# approximate the coupling filters in terms of the basis function
_, coupling_basis = basis.evaluate_on_grid(coupling_filter_bank.shape[0])
coupling_coeff = simulation.regress_filter(coupling_filter_bank, coupling_basis)
intercept = -4 * np.ones(n_neurons)

# %%
# We can explore visualize the coupling filters and the input.
# We can check that our approximation worked by plotting the original filters
# and the basis expansion

# plot coupling functions
n_basis_coupling = coupling_basis.shape[1]
fig, axs = plt.subplots(2,2)
fig, axs = plt.subplots(n_neurons, n_neurons)
plt.suptitle("Coupling filters")
for unit_i in range(2):
for unit_j in range(2):
axs[unit_i,unit_j].set_title(f"unit {unit_j} -> unit {unit_i}")
coeff = basis_coeff[unit_i, unit_j * n_basis_coupling: (unit_j + 1) * n_basis_coupling]
axs[unit_i, unit_j].plot(np.dot(coupling_basis, coeff))
for unit_i in range(n_neurons):
for unit_j in range(n_neurons):
axs[unit_i, unit_j].set_title(f"unit {unit_j} -> unit {unit_i}")
coeff = coupling_coeff[unit_i, unit_j]
axs[unit_i, unit_j].plot(coupling_filter_bank[:, unit_i, unit_j], label="gamma difference")
axs[unit_i, unit_j].plot(np.dot(coupling_basis, coeff), ls="--", color="k", label="basis function")
axs[0, 0].legend()
plt.tight_layout()

fig, axs = plt.subplots(1,1)
plt.title("Feedforward inputs")
plt.plot(feedforward_input[:, 0])
# %%
# Define a square-pulse stimulus current for the first neuron, and no stimulus for
# the second neuron

# define the square-pulse current parameters
simulation_duration = 1000
stimulus_onset = 200
stimulus_offset = 500
stimulus_intensity = 1.5

# create the input tensor of shape (n_samples, n_neurons, n_dimension_stimuli)
feedforward_input = np.zeros((simulation_duration, n_neurons, 1))
# inject square input to the first neuron only
feedforward_input[stimulus_onset: stimulus_offset, 0] = stimulus_intensity

# plot the input
fig, axs = plt.subplots(1,2)
plt.suptitle("Feedforward inputs")
axs[0].set_title("Input to neuron 0")
axs[0].plot(feedforward_input[:, 0])

axs[1].set_title("Input to neuron 1")
axs[1].plot(feedforward_input[:, 1])
axs[1].set_ylim(axs[0].get_ylim())


# the input for the simulation will be the dot product
# of input_coeff with the feedforward_input
input_coeff = np.ones((n_neurons, 1))

# stack the coefficients in a single matrix
basis_coeff = np.hstack((coupling_coeff.reshape(n_neurons, -1), input_coeff))

# initialize the spikes for the recurrent simulation
init_spikes = np.zeros((coupling_filter_duration, n_neurons))

# %%
# We can now simulate spikes by calling the `simulate_recurrent` method.
@@ -277,7 +343,7 @@
# call simulate, with both the recurrent coupling
# and the input
spikes, rates = model.simulate_recurrent(
random_key,
jax.random.PRNGKey(123),
feedforward_input=feedforward_input,
coupling_basis_matrix=coupling_basis,
init_y=init_spikes
Expand All @@ -298,8 +364,8 @@
p0, = plt.plot(rates[:, 0])
p1, = plt.plot(rates[:, 1])

plt.vlines(np.where(spikes[:, 0])[0], 0.00, 0.01, color=p0.get_color(), label="neu 0")
plt.vlines(np.where(spikes[:, 1])[0], -0.01, 0.00, color=p1.get_color(), label="neu 1")
plt.vlines(np.where(spikes[:, 0])[0], 0.00, 0.01, color=p0.get_color(), label="rate neuron 0")
plt.vlines(np.where(spikes[:, 1])[0], -0.01, 0.00, color=p1.get_color(), label="rate neuron 1")
plt.plot(np.exp(basis_coeff[0, -1] * feedforward_input[:, 0, 0] + intercept[0]), color='k', lw=0.8, label="stimulus")
ax.add_patch(patch)
plt.ylim(-0.011, .13)
1 change: 1 addition & 0 deletions src/nemos/__init__.py
@@ -7,5 +7,6 @@
observation_models,
regularizer,
sample_points,
simulation,
utils,
)
4 changes: 4 additions & 0 deletions src/nemos/basis.py
@@ -629,6 +629,8 @@ def evaluate(self, sample_pts: NDArray) -> NDArray:
The evaluation is performed by looping over each element and using `splev`
from SciPy to compute the basis values.
"""
(sample_pts,) = self._check_evaluate_input(sample_pts)

# add knots
knot_locs = self._generate_knots(sample_pts, 0.0, 1.0)

@@ -711,6 +713,8 @@ def evaluate(self, sample_pts: NDArray) -> NDArray:
SciPy to compute the basis values.

"""
(sample_pts,) = self._check_evaluate_input(sample_pts)

knot_locs = self._generate_knots(sample_pts, 0.0, 1.0, is_cyclic=True)

# for cyclic, do not repeat knots