Glm class restructure #41

Merged
merged 261 commits into from
Dec 14, 2023
Changes from 193 commits
Commits
5507c63
fixed naming and improved docstrings
BalzaniEdoardo Aug 21, 2023
b9cb88f
initialize parameters to none
BalzaniEdoardo Aug 21, 2023
451be3a
improved docstrings, started off tests
BalzaniEdoardo Aug 21, 2023
e02f722
test glm started
BalzaniEdoardo Aug 21, 2023
94efb2c
updated tests
BalzaniEdoardo Aug 23, 2023
d88ee96
test completed, need docstrings
BalzaniEdoardo Aug 23, 2023
f40b1e1
linted and removed dandi test
BalzaniEdoardo Aug 23, 2023
07c9e56
linted src
BalzaniEdoardo Aug 23, 2023
bcd7cf9
improved description of simulate
BalzaniEdoardo Aug 23, 2023
5fafb5d
improved description of simulate
BalzaniEdoardo Aug 23, 2023
28db00c
bugfixed config
BalzaniEdoardo Aug 23, 2023
405b17f
improved logic of fit, add check nans
BalzaniEdoardo Aug 24, 2023
ba07cff
added short docstrings to tests
BalzaniEdoardo Aug 24, 2023
3f4cab2
tested model_base
BalzaniEdoardo Aug 24, 2023
f2ac53e
started off the developers' notes?
BalzaniEdoardo Aug 25, 2023
ff2a549
base_model restructured and modified note
BalzaniEdoardo Aug 25, 2023
b9b5b85
added docstrings to model_base.py
BalzaniEdoardo Aug 25, 2023
096b306
massive refractoring, makes sure that score and simulate has a GLM-ty…
BalzaniEdoardo Aug 25, 2023
8d2622c
fixed example in Error
BalzaniEdoardo Aug 25, 2023
5d80e9d
linted all
BalzaniEdoardo Aug 25, 2023
6c3da95
switched residual deviance to public
BalzaniEdoardo Aug 25, 2023
a342808
bugfixed tests
BalzaniEdoardo Aug 26, 2023
91680e1
refractor classes names
BalzaniEdoardo Aug 26, 2023
c8dd031
bugfixed test_base
BalzaniEdoardo Aug 26, 2023
c56fd7e
added an end-to-end fit
BalzaniEdoardo Aug 28, 2023
7e21a36
improved docstrings of base_class
BalzaniEdoardo Aug 28, 2023
02be1da
improved notes
BalzaniEdoardo Aug 28, 2023
ffd640b
added scheme of model classes
BalzaniEdoardo Aug 29, 2023
9c6e135
linted glm.py, removed basic_test.py
BalzaniEdoardo Aug 29, 2023
80b0676
reviewed glm note
BalzaniEdoardo Aug 29, 2023
a0bafdc
update docstring for simulate
BalzaniEdoardo Aug 29, 2023
d349b0d
update docstring for simulate
BalzaniEdoardo Aug 29, 2023
5ff37e1
improved structure
BalzaniEdoardo Aug 29, 2023
6a5840e
imrpoved docstrings
BalzaniEdoardo Aug 29, 2023
29a8dfd
linted
BalzaniEdoardo Aug 29, 2023
32277b5
added device check on fit
BalzaniEdoardo Aug 31, 2023
9567546
Update 02-base_class.md
BalzaniEdoardo Sep 6, 2023
d6458b2
improved exception messages
BalzaniEdoardo Sep 6, 2023
31d1b85
flexible data_type
BalzaniEdoardo Sep 6, 2023
a9c2a9e
linted
BalzaniEdoardo Sep 6, 2023
f61b1f7
data_type as parameter
BalzaniEdoardo Sep 6, 2023
3034a9a
linted
BalzaniEdoardo Sep 6, 2023
aec419e
improved docstrings
BalzaniEdoardo Sep 6, 2023
0d5e65a
pydoctest linting
BalzaniEdoardo Sep 6, 2023
1c54339
pydoctest linting
BalzaniEdoardo Sep 6, 2023
c415c68
started loss and obs noise
BalzaniEdoardo Sep 7, 2023
0912e91
added tests
BalzaniEdoardo Sep 9, 2023
7e613d9
added comprehensive testing
BalzaniEdoardo Sep 9, 2023
7cfd5c2
add code of conduct
BalzaniEdoardo Sep 9, 2023
b22de56
fixed typing
BalzaniEdoardo Sep 10, 2023
4eddb67
use jax config as default dtype
BalzaniEdoardo Sep 10, 2023
79b5997
make sure the mask is configured before the runner is called
BalzaniEdoardo Sep 10, 2023
7504215
removed unused deps
BalzaniEdoardo Sep 10, 2023
f341c86
glm edited after discussion
BalzaniEdoardo Sep 11, 2023
f56bd4f
fixed example
BalzaniEdoardo Sep 11, 2023
779bf2d
removed firing and spiking
BalzaniEdoardo Sep 12, 2023
6d43fb4
start test refractoring
BalzaniEdoardo Sep 13, 2023
3fc04b7
fixed linking of noise model
BalzaniEdoardo Sep 13, 2023
db8c9c0
imporoved text
BalzaniEdoardo Sep 13, 2023
93c3cb6
refractor of test to be continued
BalzaniEdoardo Sep 13, 2023
7fc6d6e
refractored glm tests with helper function
BalzaniEdoardo Sep 14, 2023
f87e01d
removed data_type, improved docstrings, fixed tests
BalzaniEdoardo Sep 18, 2023
0c2d173
fixed deprecations
BalzaniEdoardo Sep 19, 2023
788f3d4
added tests for corner cases
BalzaniEdoardo Sep 19, 2023
ee00d13
set solver_name and solver_kwargs as propertied
BalzaniEdoardo Sep 19, 2023
263c530
set noise model link func as a property
BalzaniEdoardo Sep 19, 2023
3cfc0b0
linted
BalzaniEdoardo Sep 19, 2023
e430c35
bugfixed tests (compiled jax funcs start with jaxlib)
BalzaniEdoardo Sep 19, 2023
55ec5e3
added tests for unreg solver
BalzaniEdoardo Sep 19, 2023
064222d
added unreg tests and pointed to src for coverage
BalzaniEdoardo Sep 19, 2023
b37e9b5
added unreg tests and pointed to src for coverage
BalzaniEdoardo Sep 19, 2023
01c86d0
improved docs
BalzaniEdoardo Sep 20, 2023
753a64d
edited 02-base_class.md note
BalzaniEdoardo Sep 20, 2023
29eb4ba
edited note 03-glm.md
BalzaniEdoardo Sep 20, 2023
23e0557
edited note 03-glm.md
BalzaniEdoardo Sep 21, 2023
6e83ede
bugfixed plots
BalzaniEdoardo Sep 21, 2023
5f8b78a
first pass to the noise model dev note
BalzaniEdoardo Sep 21, 2023
b30b48f
completed note 04-noise_model.md
BalzaniEdoardo Sep 21, 2023
ad99760
started new note on solver
BalzaniEdoardo Sep 22, 2023
3211e8a
minor edits and grammar checks
BalzaniEdoardo Sep 22, 2023
8e39d5e
added a glossary
BalzaniEdoardo Sep 22, 2023
2b2962d
edited note
BalzaniEdoardo Sep 22, 2023
0aec19d
edited notes, added hyeperlinks
BalzaniEdoardo Sep 22, 2023
4e3f51a
linted
BalzaniEdoardo Sep 22, 2023
3767617
added one line docstrings for abstract methods
BalzaniEdoardo Sep 22, 2023
34e2570
linted docstrings
BalzaniEdoardo Sep 22, 2023
c9df746
linted docstrings with pydocstyle
BalzaniEdoardo Sep 22, 2023
fc962a5
mypy fixes
BalzaniEdoardo Sep 22, 2023
1cf9d07
changed notes order
BalzaniEdoardo Sep 22, 2023
b0a9618
fixed examples
BalzaniEdoardo Sep 22, 2023
48e9e81
restart the ci
BalzaniEdoardo Sep 22, 2023
00b2401
show inherited methods in docs
BalzaniEdoardo Sep 22, 2023
6dee6af
merged main
BalzaniEdoardo Sep 22, 2023
0cd45f9
added new modules to index
BalzaniEdoardo Sep 26, 2023
44e5558
Update docs/developers_notes/02-base_class.md
BalzaniEdoardo Oct 26, 2023
b8ca932
Update docs/developers_notes/02-base_class.md
BalzaniEdoardo Oct 26, 2023
65b2df4
changed CI to edit with isort and not just check
BalzaniEdoardo Oct 26, 2023
faf0e60
removed index
BalzaniEdoardo Oct 26, 2023
89112fd
replaces yaml with json. removed yaml dependency
BalzaniEdoardo Oct 26, 2023
226a752
Update mkdocs.yml
BalzaniEdoardo Oct 26, 2023
23f1cc5
removed leftover filter in mkdocstrings
BalzaniEdoardo Oct 26, 2023
099d517
testing coverage
BalzaniEdoardo Oct 30, 2023
b363705
Merge branch 'glm_class_restructure' of github.com:flatironinstitute/…
BalzaniEdoardo Oct 30, 2023
ede53f1
switched to json in tests
BalzaniEdoardo Oct 30, 2023
e43991c
requirement updated
BalzaniEdoardo Oct 30, 2023
e336d9e
bugfix version
BalzaniEdoardo Oct 30, 2023
a4eb0eb
coverage over src
BalzaniEdoardo Oct 30, 2023
c868f68
changed function naming
BalzaniEdoardo Oct 30, 2023
e434236
refractored name to Base
BalzaniEdoardo Oct 30, 2023
6c39aa3
removed hidden attributes
BalzaniEdoardo Oct 30, 2023
6020290
refer to jax
BalzaniEdoardo Oct 30, 2023
10af2fe
moved conversion to utils.py
BalzaniEdoardo Oct 30, 2023
379ae4c
moved check invalid entry to utils
BalzaniEdoardo Oct 31, 2023
3229c05
Update src/neurostatslib/base_class.py
BalzaniEdoardo Oct 31, 2023
255d969
Update src/neurostatslib/glm.py
BalzaniEdoardo Oct 31, 2023
489853f
linted
BalzaniEdoardo Oct 31, 2023
89df69c
changed mock tests and changed get_params
BalzaniEdoardo Oct 31, 2023
914ce1b
fixed typo
BalzaniEdoardo Oct 31, 2023
23310cf
create a check array function
BalzaniEdoardo Oct 31, 2023
94d5523
preprocess func private
BalzaniEdoardo Oct 31, 2023
b50017e
fixed tests
BalzaniEdoardo Oct 31, 2023
759a23d
renamed paramters in simulate preproc
BalzaniEdoardo Oct 31, 2023
9fef6c8
linted
BalzaniEdoardo Oct 31, 2023
f10e4e9
linted
BalzaniEdoardo Oct 31, 2023
3152065
added test for noise model link
BalzaniEdoardo Oct 31, 2023
9b55807
Update src/neurostatslib/noise_model.py
BalzaniEdoardo Oct 31, 2023
bbe8213
Update src/neurostatslib/noise_model.py
BalzaniEdoardo Oct 31, 2023
0f7d67b
linted
BalzaniEdoardo Oct 31, 2023
293172f
Merge branch 'glm_class_restructure' of github.com:flatironinstitute/…
BalzaniEdoardo Oct 31, 2023
6cc4822
renamed emission_probability
BalzaniEdoardo Oct 31, 2023
2aae9ef
Update src/neurostatslib/noise_model.py
BalzaniEdoardo Nov 1, 2023
19456b1
Update src/neurostatslib/glm.py
BalzaniEdoardo Nov 1, 2023
a1f1655
Merge branch 'glm_class_restructure' of github.com:flatironinstitute/…
BalzaniEdoardo Nov 1, 2023
aa454fb
Update src/neurostatslib/glm.py
BalzaniEdoardo Nov 1, 2023
f74d3ab
Update src/neurostatslib/glm.py
BalzaniEdoardo Nov 1, 2023
b29b6d7
Update src/neurostatslib/solver.py
BalzaniEdoardo Nov 1, 2023
dbdca86
Update src/neurostatslib/glm.py
BalzaniEdoardo Nov 1, 2023
881d503
Update src/neurostatslib/solver.py
BalzaniEdoardo Nov 1, 2023
bab3a9b
Merge branch 'glm_class_restructure' of github.com:flatironinstitute/…
BalzaniEdoardo Nov 1, 2023
c340dd2
linted
BalzaniEdoardo Nov 1, 2023
448d370
removed mask from proximal operator
BalzaniEdoardo Nov 1, 2023
61a2511
fixed solver mask
BalzaniEdoardo Nov 1, 2023
0183d9e
removed the renaming
BalzaniEdoardo Nov 1, 2023
ab3cd99
removed the any reference to noise
BalzaniEdoardo Nov 1, 2023
2991896
allow a loss function with 3 arguments with any variable names
BalzaniEdoardo Nov 2, 2023
2fdd66f
linted
BalzaniEdoardo Nov 2, 2023
5d43790
grammar fixed
BalzaniEdoardo Nov 2, 2023
7c12183
changed description of check func
BalzaniEdoardo Nov 2, 2023
27582cc
generalized get_runner to accept any inputs
BalzaniEdoardo Nov 3, 2023
bdefe97
simplified get_runner
BalzaniEdoardo Nov 3, 2023
86c4df8
simplified get_runner
BalzaniEdoardo Nov 3, 2023
8af2a21
solver set
BalzaniEdoardo Nov 3, 2023
b27cd31
modified glm moving the checks to the setter
BalzaniEdoardo Nov 3, 2023
3baaefb
added logistic as example
BalzaniEdoardo Nov 3, 2023
d3664c3
improved docstring for score
BalzaniEdoardo Nov 3, 2023
10255c6
bugfixed deviance
BalzaniEdoardo Nov 3, 2023
ecb069e
renamed deviance
BalzaniEdoardo Nov 3, 2023
662dfbc
moved raise in the if close
BalzaniEdoardo Nov 4, 2023
3c0fb3f
commented scanf
BalzaniEdoardo Nov 4, 2023
dbd3bbd
improved scanf algorithm
BalzaniEdoardo Nov 6, 2023
5151f7a
improved comments
BalzaniEdoardo Nov 6, 2023
b7aaa25
add test loss is callable
BalzaniEdoardo Nov 6, 2023
355199d
linted
BalzaniEdoardo Nov 6, 2023
dc02446
completed first round code revisions
BalzaniEdoardo Nov 7, 2023
c838e49
refractor parameter names
BalzaniEdoardo Nov 7, 2023
9d7ccd2
added a warning on the demo
BalzaniEdoardo Nov 7, 2023
7570908
checked conversion to nemos
BalzaniEdoardo Nov 7, 2023
2b87edc
no change
BalzaniEdoardo Nov 7, 2023
13135ce
removed refs to neurostatslib
BalzaniEdoardo Nov 7, 2023
14a17b8
removed refs to neurostatslib
BalzaniEdoardo Nov 7, 2023
b315f46
removed refs to neurostatslib
BalzaniEdoardo Nov 7, 2023
1f408fa
removed gen lin models and linted
BalzaniEdoardo Nov 7, 2023
77065a0
linted
BalzaniEdoardo Nov 7, 2023
b902219
removed gen-lin-mod
BalzaniEdoardo Nov 7, 2023
3aee269
added dev deps
BalzaniEdoardo Nov 7, 2023
cb381ff
commit pyproject.toml
BalzaniEdoardo Nov 7, 2023
db62140
linted
BalzaniEdoardo Nov 7, 2023
724522f
fixed links
BalzaniEdoardo Nov 7, 2023
15f062c
fixed links
BalzaniEdoardo Nov 7, 2023
1ef0eda
linted
BalzaniEdoardo Nov 7, 2023
a2a5cb7
fixed spacing
BalzaniEdoardo Nov 8, 2023
4ea0276
fixed the shapes
BalzaniEdoardo Nov 8, 2023
2a389cf
removed unused funcs
BalzaniEdoardo Nov 21, 2023
252b78b
fixed test
BalzaniEdoardo Nov 21, 2023
4592cfd
removed convert utils
BalzaniEdoardo Nov 21, 2023
38d6d51
fixed tests exceptions
BalzaniEdoardo Nov 21, 2023
a2981bc
fixed tests exceptions
BalzaniEdoardo Nov 21, 2023
59393f9
fixed regex tests
BalzaniEdoardo Nov 21, 2023
d40a272
fixed regex tests glm
BalzaniEdoardo Nov 21, 2023
3b43e04
fixed hyperlink
BalzaniEdoardo Nov 22, 2023
84e8925
improved _score docstrings
BalzaniEdoardo Nov 22, 2023
9debb46
added 2 types of pr2
BalzaniEdoardo Nov 24, 2023
2bdd293
float_eps removed
BalzaniEdoardo Nov 27, 2023
06dc039
moved multi-array device put to utils.py
BalzaniEdoardo Nov 27, 2023
ef170e3
improved refs
BalzaniEdoardo Nov 27, 2023
3a0e990
refractored names
BalzaniEdoardo Nov 27, 2023
412727e
fixed tests
BalzaniEdoardo Nov 27, 2023
5ec8430
improved refs to solver
BalzaniEdoardo Nov 27, 2023
5665711
refractor note solver part1
BalzaniEdoardo Nov 27, 2023
266c077
refractor module name
BalzaniEdoardo Nov 28, 2023
7362b0c
improved docstring proximal operator
BalzaniEdoardo Nov 28, 2023
1b2e055
removed ## on module docstrings
BalzaniEdoardo Nov 29, 2023
62be95f
removed note on prox op, updated baseclass
BalzaniEdoardo Nov 29, 2023
8c61fd1
removed get_runner and use super().instantiate_solver instead
BalzaniEdoardo Nov 29, 2023
46c2a58
edited note
BalzaniEdoardo Nov 29, 2023
3d1c113
switched to fixture
BalzaniEdoardo Nov 30, 2023
eb0a677
started refractoring tests
BalzaniEdoardo Nov 30, 2023
c9d6f0a
refractored test_fit_weights_dimensionality
BalzaniEdoardo Nov 30, 2023
42e8a2e
refractored test_fit_intercepts_dimensionality
BalzaniEdoardo Nov 30, 2023
c3e7b29
refractored test_fit_init_params_type
BalzaniEdoardo Nov 30, 2023
332a7eb
refractored test_fit_n_neuron_match_weights
BalzaniEdoardo Nov 30, 2023
90bbc87
refractored test_fit_n_neuron_match_baseline_rate
BalzaniEdoardo Nov 30, 2023
6e49aa6
refractored test_fit_n_neuron_match_baseline_rate
BalzaniEdoardo Nov 30, 2023
2a043f0
refractored test_fit_n_neuron_match_x
BalzaniEdoardo Nov 30, 2023
2468b73
refractored test_fit_n_neuron_match_y
BalzaniEdoardo Nov 30, 2023
4b6565b
refractored test_fit_x_dimensionality
BalzaniEdoardo Nov 30, 2023
0513540
refractored test_fit_y_dimensionality
BalzaniEdoardo Nov 30, 2023
fb862da
refractored test_fit_n_feature_consistency_weights
BalzaniEdoardo Nov 30, 2023
88fe8c8
refractored test_fit_n_feature_consistency_x
BalzaniEdoardo Nov 30, 2023
34b9b84
refractored test_fit_time_points_x
BalzaniEdoardo Nov 30, 2023
1fc86b2
test_fit_time_points_y
BalzaniEdoardo Nov 30, 2023
0071c07
test_score_n_neuron_match_x
BalzaniEdoardo Nov 30, 2023
71e6976
test_score_n_neuron_match_y
BalzaniEdoardo Nov 30, 2023
56e86b0
test_score_x_dimensionality
BalzaniEdoardo Nov 30, 2023
1b2215f
completed refractoring
BalzaniEdoardo Nov 30, 2023
2786607
linted everything
BalzaniEdoardo Nov 30, 2023
85de28f
fixed naming
BalzaniEdoardo Nov 30, 2023
41f3771
linted glm
BalzaniEdoardo Nov 30, 2023
29aab97
removed sklearn dep
BalzaniEdoardo Nov 30, 2023
979b3b4
removed unused method
BalzaniEdoardo Dec 11, 2023
899ca98
refractored names
BalzaniEdoardo Dec 11, 2023
d5c220e
undoes some unhelpful python black ormatting
billbrod Dec 11, 2023
2e2c071
fixed docstrings
BalzaniEdoardo Dec 11, 2023
9645f9d
fix docstring
billbrod Dec 11, 2023
1ee0de4
choen -> cohen
billbrod Dec 11, 2023
4cc2d34
added statsmodels equivalence in docstrings
BalzaniEdoardo Dec 11, 2023
e76ce1a
Merge branch 'glm_class_restructure' of github.com:flatironinstitute/…
BalzaniEdoardo Dec 11, 2023
40249ea
resolved conflicts
BalzaniEdoardo Dec 11, 2023
d016665
linted flake8
BalzaniEdoardo Dec 11, 2023
e9bd25c
generate parameters for simulaiton
BalzaniEdoardo Dec 12, 2023
4679dc2
removed json in docs
BalzaniEdoardo Dec 12, 2023
5d32df9
modified the glm demo
BalzaniEdoardo Dec 13, 2023
2930373
added testing
BalzaniEdoardo Dec 14, 2023
46092ed
fixed test by removing dep on json
BalzaniEdoardo Dec 14, 2023
2a815b4
added test of correctness for the lsq
BalzaniEdoardo Dec 14, 2023
58a5aba
added raises
BalzaniEdoardo Dec 14, 2023
0d1207c
Update src/nemos/simulation.py
BalzaniEdoardo Dec 14, 2023
c2dd948
Update src/nemos/simulation.py
BalzaniEdoardo Dec 14, 2023
3034922
Update src/nemos/simulation.py
BalzaniEdoardo Dec 14, 2023
2a5259b
Merge branch 'glm_class_restructure' of github.com:flatironinstitute/…
BalzaniEdoardo Dec 14, 2023
3 changes: 3 additions & 0 deletions .gitignore
@@ -143,3 +143,6 @@ docs/generated/

# vscode
.vscode/

# nwb cache
nwb-cache/
@@ -1,4 +1,4 @@
# The Basis Module
# The `basis` Module

## Introduction

102 changes: 102 additions & 0 deletions docs/developers_notes/02-base_class.md
@@ -0,0 +1,102 @@
# The `base_class` Module

## Introduction

The `base_class` module introduces the `Base` class and abstract classes defining broad model categories. These abstract classes **must** inherit from `Base`.

The `Base` class is envisioned as the foundational component for any object type (e.g., regression, dimensionality reduction, clustering, observation models, solvers, etc.). In contrast, abstract classes derived from `Base` define overarching object categories (e.g., `base_class.BaseRegressor` is the building block for GLMs, GAMs, etc., while `observation_models.Observations` is the building block for Poisson observations, Gamma observations, etc.).

Designed to be compatible with the `scikit-learn` API, the class structure aims to facilitate access to `scikit-learn`'s robust pipeline and cross-validation modules. This is achieved while leveraging the accelerated computational capabilities of `jax` and `jaxopt` in the backend, which is essential for analyzing extensive neural recordings and fitting large models.

Below is a scheme of how we envision the architecture of the `nemos` models.

```
Abstract Class Base
├─ Abstract Subclass BaseRegressor
│ │
│ └─ Concrete Subclass GLM
│ │
│ └─ Concrete Subclass RecurrentGLM
├─ Abstract Subclass BaseManifold *(not implemented yet)
│ │
│ ...
├─ Abstract Subclass Solver
│ │
│ ├─ Concrete Subclass UnRegularizedSolver
│ │
│ ├─ Concrete Subclass RidgeSolver
│ ...
├─ Abstract Subclass Observations
│ │
│ ├─ Concrete Subclass PoissonObservations
│ │
│ ├─ Concrete Subclass GammaObservations *(not implemented yet)
│ ...
...
```

!!! Example
The current package version includes a concrete class named `nemos.glm.GLM`. This class inherits from `BaseRegressor`, since it falls under the "GLM regression" category; `BaseRegressor` in turn inherits from `Base`.
Like any `BaseRegressor`, it **must** implement the `fit`, `score`, `predict`, and `simulate` methods.


## The Class `base_class.Base`

The `Base` class aligns with the `scikit-learn` API for `base.BaseEstimator`. This alignment is achieved by implementing the `get_params` and `set_params` methods, essential for `scikit-learn` compatibility and foundational for all model implementations. Additionally, the class provides auxiliary helper methods to identify available computational devices (such as GPUs and TPUs) and to facilitate data transfer to these devices.

For a detailed understanding, consult the [`scikit-learn` API Reference](https://scikit-learn.org/stable/modules/classes.html) and [`BaseEstimator`](https://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html).

!!! Note
We've intentionally omitted the `get_metadata_routing` method. Given its current experimental status and its lack of relevance to the `GLM` class, this method was excluded. Should future needs arise around parameter routing, consider directly inheriting from `sklearn.base.BaseEstimator`. More information can be found [here](https://scikit-learn.org/stable/metadata_routing.html#metadata-routing).

### Public methods

- **`get_params`**: The `get_params` method retrieves parameters set during model instance initialization. Opting for a deep inspection allows the method to assess nested object parameters, resulting in a comprehensive parameter dictionary.
- **`set_params`**: The `set_params` method offers a mechanism to adjust or set an estimator's parameters. It's versatile, accommodating both individual estimators and more complex nested structures like pipelines. Feeding an unrecognized parameter will raise a `ValueError`.
- **`select_target_device`**: Selects either "cpu", "gpu" or "tpu" as the device.
- **`device_put`**: Sends arrays to device, if not on device already.
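
The `get_params` / `set_params` behavior described above can be sketched in a few lines of plain Python. This is a minimal illustration of the `scikit-learn` convention (parameters are read from the `__init__` signature, nested parameters use `<name>__<sub_name>` keys), not the `nemos` implementation; `ToyBase`, `ToySolver`, and `ToyModel` are hypothetical names introduced only for this example.

```python
import inspect


class ToyBase:
    """Hypothetical stand-in for `Base`; not the nemos code."""

    def get_params(self, deep=True):
        # Parameter names are taken from the signature of __init__.
        names = [
            name
            for name in inspect.signature(type(self).__init__).parameters
            if name != "self"
        ]
        params = {}
        for name in names:
            value = getattr(self, name)
            params[name] = value
            # With deep=True, also expose the parameters of nested objects
            # under "<name>__<sub_name>" keys.
            if deep and hasattr(value, "get_params"):
                for sub_name, sub_value in value.get_params(deep=True).items():
                    params[f"{name}__{sub_name}"] = sub_value
        return params

    def set_params(self, **params):
        valid = self.get_params(deep=True)
        for key, value in params.items():
            name, _, sub_key = key.partition("__")
            if name not in valid:
                # Unrecognized parameters raise a ValueError.
                raise ValueError(
                    f"Invalid parameter {name!r} for {type(self).__name__}."
                )
            if sub_key:
                # Delegate nested keys to the nested object.
                getattr(self, name).set_params(**{sub_key: value})
            else:
                setattr(self, name, value)
        return self


class ToySolver(ToyBase):  # hypothetical nested object
    def __init__(self, strength=1.0):
        self.strength = strength


class ToyModel(ToyBase):  # hypothetical model
    def __init__(self, solver=None, tol=1e-6):
        self.solver = solver
        self.tol = tol


model = ToyModel(solver=ToySolver(strength=0.5))
model.set_params(tol=1e-3, solver__strength=2.0)
params = model.get_params()
```

Storing every `__init__` argument under an attribute of the same name is what makes this introspection work, which is why subclasses are asked not to override these two methods.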

## The Abstract Class `base_class.BaseRegressor`

`BaseRegressor` is an abstract class that inherits from `Base` and requires subclasses to implement the abstract methods `fit`, `predict`, `score`, and `simulate`. This ensures integration with `scikit-learn` pipelines and cross-validation procedures.

### Abstract Methods

For subclasses derived from `BaseRegressor` to function correctly, they must implement the following:

1. `fit`: Adapt the model using input data `X` and corresponding observations `y`.
2. `predict`: Provide predictions based on the trained model and input data `X`.
3. `score`: Score the accuracy of model predictions using input data `X` against the actual observations `y`.
4. `simulate`: Simulate data based on the trained regression model.
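
As a rough illustration of this contract, the sketch below declares the four abstract methods with Python's `abc` module and implements them in a deliberately trivial subclass. Everything here is hypothetical: `ToyBaseRegressor` and `MeanRegressor` are made-up names, the method signatures (in particular that of `simulate`) are assumptions rather than the `nemos` signatures, and plain Python lists replace `jax` arrays.

```python
import abc
import random


class ToyBaseRegressor(abc.ABC):
    """Hypothetical stand-in; the real class also inherits from `Base`."""

    @abc.abstractmethod
    def fit(self, X, y):
        """Fit the model to inputs X and observations y."""

    @abc.abstractmethod
    def predict(self, X):
        """Predict observations from inputs X."""

    @abc.abstractmethod
    def score(self, X, y):
        """Score predictions for X against observations y."""

    @abc.abstractmethod
    def simulate(self, random_key, feedforward_input):
        """Simulate data from the fitted model (signature is an assumption)."""


class MeanRegressor(ToyBaseRegressor):
    """Toy model that always predicts the mean of the training targets."""

    def fit(self, X, y):
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X):
        return [self.mean_ for _ in X]

    def score(self, X, y):
        # Negative mean squared error, so that larger is better.
        predictions = self.predict(X)
        return -sum((p - t) ** 2 for p, t in zip(predictions, y)) / len(y)

    def simulate(self, random_key, feedforward_input):
        # Gaussian noise around the mean; one sample per input row.
        rng = random.Random(random_key)
        return [rng.gauss(self.mean_, 1.0) for _ in feedforward_input]


regressor = MeanRegressor().fit([[0.0], [1.0]], [1.0, 3.0])
prediction = regressor.predict([[5.0]])
```

Instantiating `ToyBaseRegressor` directly, or a subclass that omits one of the four methods, raises a `TypeError`, which is how the abstract class enforces the contract.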

### Public Methods

To ensure the consistency and conformity of input data, the `BaseRegressor` introduces two public preprocessing methods:

1. `preprocess_fit`: Validates the input for the `fit` method and converts it into the desired `jax.numpy.ndarray` format. If necessary, this method can initialize model parameters using default values.
2. `preprocess_simulate`: Validates and converts inputs for the `simulate` method. This method confirms the integrity of the feedforward input and, when provided, the initial values for feedback.

### Auxiliary Methods

Moreover, `BaseRegressor` incorporates auxiliary methods such as `_convert_to_jnp_ndarray`, `_has_invalid_entry`
and a number of other methods for checking input consistency.

!!! Tip
Deciding between concrete and abstract methods in a superclass can be nuanced. As a general guideline: any method that is expected in all subclasses and isn't subclass-specific should be concretely implemented in the superclass. Conversely, methods that are essential for a subclass's expected behavior but vary from subclass to subclass should be abstract in the superclass. For instance, compatibility with the `sklearn.model_selection` module demands `score`, `fit`, `get_params`, and `set_params` methods. Given their specificity to individual models, `score` and `fit` are abstract in `BaseRegressor`. Conversely, as `get_params` and `set_params` are consistent across model classes, they're inherited from `Base`. This approach typifies our general implementation strategy. However, while these are sound guidelines, exceptions exist based on factors such as future extensibility, clarity, and maintainability.


## Contributor Guidelines

### Implementing Model Subclasses

When writing a new model subclass based on the `BaseRegressor` abstract class, follow these guidelines:

- **Must** inherit from the `BaseRegressor` abstract superclass.
- **Must** implement the abstract methods: `fit`, `predict`, `score`, and `simulate`.
- **Should not** overwrite the `get_params` and `set_params` methods, inherited from `Base`.
- **May** introduce auxiliary methods such as `_convert_to_jnp_ndarray` for added utility.
64 changes: 64 additions & 0 deletions docs/developers_notes/03-observation_models.md
@@ -0,0 +1,64 @@
# The `observation_models` Module

## Introduction

The `observation_models` module provides objects representing the observations of GLM-like models.

The abstract class `Observations` defines the structure of the subclasses which specify observation types, such as Poisson, Gamma, etc. These objects serve as attributes of the [`nemos.glm.GLM`](../05-glm/#the-concrete-class-glm) class, equipping the GLM with a negative log-likelihood. The observation model thereby defines the optimization objective, the deviance (which measures model fit quality), and the emission of new observations when simulating data.

## The Abstract Class `Observations`

The abstract class `Observations` is the backbone of any observation model. Any class inheriting from `Observations` must implement the `negative_log_likelihood`, `sample_generator`, `residual_deviance`, and `estimate_scale` methods.

### Abstract Methods

For subclasses derived from `Observations` to function correctly, they must implement the following:

- **negative_log_likelihood**: Computes the negative log-likelihood of the model up to a normalization constant. This method is usually part of the objective function used to learn GLM parameters.

- **sample_generator**: Returns a function that draws random observations (emissions) from the observation model. This typically invokes a `jax.random` sampler, provided some sufficient statistics[^1]. For distributions in the exponential family, the sufficient statistics are the canonical parameter and the scale. In GLMs, the canonical parameter is entirely specified by the model's weights, while the scale is either fixed (e.g., Poisson) or needs to be estimated (e.g., Gamma).

- **residual_deviance**: Computes the residual deviance based on the model's estimated rates and observations.

- **estimate_scale**: A method for estimating the scale parameter of the model.

### Public Methods

- **pseudo_r2**: Computes the pseudo-$R^2$ of the model based on the residual deviance. There is no consensus definition of the pseudo-$R^2$; the one used here is the definition by Cohen et al. 2003[^2].
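
A common deviance-based form of the pseudo-$R^2$ compares the deviance of the fitted model with that of a null model that predicts the mean observation everywhere: $R^2 = 1 - D_{\text{model}} / D_{\text{null}}$. The sketch below assumes this form; whether it matches the `nemos` implementation term by term is an assumption. The function names are hypothetical, and a Gaussian (squared-error) deviance is used purely for illustration, in which case the quantity reduces to the ordinary $R^2$.

```python
def squared_error_deviance(predicted_rate, y):
    """Gaussian deviance (sum of squared errors); illustration only."""
    return sum((t - r) ** 2 for r, t in zip(predicted_rate, y))


def pseudo_r2(predicted_rate, y, deviance_fn):
    """Deviance-based pseudo-R^2: 1 - D_model / D_null (assumed form)."""
    # The null model predicts the mean observation for every sample.
    null_rate = [sum(y) / len(y)] * len(y)
    model_deviance = deviance_fn(predicted_rate, y)
    null_deviance = deviance_fn(null_rate, y)
    return 1.0 - model_deviance / null_deviance


observations = [1.0, 2.0, 3.0]
perfect_fit = pseudo_r2(observations, observations, squared_error_deviance)
null_fit = pseudo_r2([2.0, 2.0, 2.0], observations, squared_error_deviance)
```

A perfect prediction gives a value of 1 and a prediction no better than the mean gives 0; swapping in a different `deviance_fn` (for example a Poisson deviance) changes the observation model without changing the formula.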


### Auxiliary Methods

- **_check_inverse_link_function**: Checks that the provided link function is a `Callable` of the `jax` namespace.

## Concrete `PoissonObservations` class

The `PoissonObservations` class extends the abstract `Observations` class to provide functionalities specific to the Poisson observation model. It is designed for modeling observed spike counts based on a Poisson distribution with a given rate.

### Overridden Methods

- **negative_log_likelihood**: This method computes the Poisson negative log-likelihood of the predicted rates for the observed spike counts.

- **sample_generator**: Generates random numbers from a Poisson distribution based on the given `predicted_rate`.

- **residual_deviance**: Calculates the residual deviance for a Poisson model.

- **estimate_scale**: Assigns a fixed value of 1 to the scale parameter of the Poisson model, since the Poisson distribution has a fixed scale.
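
The four Poisson methods above can be sketched with the standard formulas. This is a hedged illustration, not the `nemos` code: the function names are hypothetical, plain Python replaces `jax.numpy` and `jax.random`, the negative log-likelihood is averaged over samples (an assumption), and sampling uses Knuth's multiplication method, which is only practical for small rates.

```python
import math
import random


def poisson_negative_log_likelihood(predicted_rate, y):
    """Mean Poisson NLL, dropping the log(y!) term (constant in the parameters)."""
    return sum(r - t * math.log(r) for r, t in zip(predicted_rate, y)) / len(y)


def poisson_residual_deviance(predicted_rate, y):
    """Poisson deviance: 2 * sum(y * log(y / rate) - (y - rate))."""
    total = 0.0
    for r, t in zip(predicted_rate, y):
        # By convention, y * log(y / rate) is 0 when y == 0.
        log_term = t * math.log(t / r) if t > 0 else 0.0
        total += 2.0 * (log_term - (t - r))
    return total


def poisson_sample(rng, predicted_rate):
    """Draw one Poisson count per rate using Knuth's method."""
    samples = []
    for r in predicted_rate:
        threshold = math.exp(-r)
        count, product = 0, rng.random()
        # Multiply uniforms until the product drops to exp(-rate) or below.
        while product > threshold:
            count += 1
            product *= rng.random()
        samples.append(count)
    return samples


def poisson_estimate_scale(predicted_rate):
    """The Poisson scale parameter is fixed at 1."""
    return 1.0


rates = [0.5, 2.0, 4.0]
counts = poisson_sample(random.Random(0), rates)
```

The deviance is zero when the predicted rates equal the observed counts and grows as the fit worsens, which is what makes it usable as a goodness-of-fit measure.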

## Contributor Guidelines

To implement an observation model class, you:

- **Must** inherit from `Observations`

- **Must** provide a concrete implementation of `negative_log_likelihood`, `sample_generator`, `residual_deviance`, and `estimate_scale`.

- **Should not** reimplement the `pseudo_r2` method or the `_check_inverse_link_function` auxiliary method.

[^1]:
In statistics, a statistic is sufficient with respect to a statistical model and its associated unknown parameters if "no other statistic that can be calculated from the same sample provides any additional information as to the value of the parameters", adapted from Fisher R. A.
1922. On the mathematical foundations of theoretical statistics. *Philosophical Transactions of the Royal Society of London. Series A, Containing Papers of a Mathematical or Physical Character* 222:309–368. http://doi.org/10.1098/rsta.1922.0009.
[^2]:
Jacob Cohen, Patricia Cohen, Steven G. West, Leona S. Aiken.
*Applied Multiple Regression/Correlation Analysis for the Behavioral Sciences*.
3rd edition. Routledge, 2002. p.502. ISBN 978-0-8058-2223-6. (May 2012)