Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added linting tests to tox.ini and linted #258

Merged
merged 1 commit into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/test_base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def score(self, X, y, score_type="pseudo-r2-McFadden"):
def get_optimal_solver_params_config(self):
return None, None, None


class BadEstimator(Base):
def __init__(self, param1, *args):
super().__init__()
Expand Down
8 changes: 5 additions & 3 deletions tests/test_convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,14 @@ def test_ridge_convergence(solver_names):
y = np.random.poisson(rate)

# instantiate and fit ridge GLM with GradientDescent
model_GD = nmo.glm.GLM(regularizer_strength=1., regularizer="Ridge", solver_kwargs=dict(tol=10**-12))
model_GD = nmo.glm.GLM(
regularizer_strength=1.0, regularizer="Ridge", solver_kwargs=dict(tol=10**-12)
)
model_GD.fit(X, y)

# instantiate and fit ridge GLM with ProximalGradient
model_PG = nmo.glm.GLM(
regularizer_strength=1.,
regularizer_strength=1.0,
regularizer="Ridge",
solver_name="ProximalGradient",
solver_kwargs=dict(tol=10**-12),
Expand Down Expand Up @@ -109,7 +111,7 @@ def test_lasso_convergence(solver_name):
# instantiate and fit GLM with ProximalGradient
model_PG = nmo.glm.GLM(
regularizer="Lasso",
regularizer_strength=1.,
regularizer_strength=1.0,
solver_name="ProximalGradient",
solver_kwargs=dict(tol=10**-12),
)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,9 @@ def test_tree_structure_match(self, trial_counts, axis):
conv = convolve.create_convolutional_predictor(
basis_matrix, trial_counts, axis=axis
)
assert jax.tree_util.tree_structure(trial_counts) == jax.tree_util.tree_structure(conv)
assert jax.tree_util.tree_structure(
trial_counts
) == jax.tree_util.tree_structure(conv)

@pytest.mark.parametrize("axis", [0, 1, 2])
@pytest.mark.parametrize(
Expand Down
13 changes: 10 additions & 3 deletions tests/test_glm_initialization.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import warnings

import jax
import jax.numpy as jnp
import numpy as np
import pytest

import nemos as nmo
import warnings


@pytest.mark.parametrize(
"non_linearity",
Expand Down Expand Up @@ -56,14 +58,19 @@ def test_initialization_error_nan_input(non_linearity, expectation):
inverse_link_function=non_linearity, y=output_y
)


def test_initialization_error_non_invertible():
"""Initialize invalid."""
output_y = np.random.uniform(size=100)
inv_link = lambda x: jax.nn.softplus(x) + 10
with pytest.raises(ValueError, match="Failed to initialize the model intercept.+Please, provide"):
with pytest.raises(
ValueError, match="Failed to initialize the model intercept.+Please, provide"
):
with warnings.catch_warnings():
# ignore the warning raised by the root-finder (there is no root)
warnings.filterwarnings("ignore", category=RuntimeWarning, message="Tolerance of")
warnings.filterwarnings(
"ignore", category=RuntimeWarning, message="Tolerance of"
)
nmo.initialize_regressor.initialize_intercept_matching_mean_rate(
inverse_link_function=inv_link, y=output_y
)
88 changes: 66 additions & 22 deletions tests/test_observation_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ def test_pseudo_r2_range(self, score_type, poissonGLM_model_instantiation):
Compute the pseudo-r2 and check that is < 1.
"""
_, y, model, _, firing_rate = poissonGLM_model_instantiation
pseudo_r2 = model.observation_model.pseudo_r2(y, firing_rate, score_type=score_type)
pseudo_r2 = model.observation_model.pseudo_r2(
y, firing_rate, score_type=score_type
)
if (pseudo_r2 > 1) or (pseudo_r2 < 0):
raise ValueError(f"pseudo-r2 of {pseudo_r2} outside the [0,1] range!")

Expand All @@ -157,7 +159,9 @@ def test_pseudo_r2_mean(self, score_type, poissonGLM_model_instantiation):
Check that the pseudo-r2 of the null model is 0.
"""
_, y, model, _, _ = poissonGLM_model_instantiation
pseudo_r2 = model.observation_model.pseudo_r2(y, y.mean(), score_type=score_type)
pseudo_r2 = model.observation_model.pseudo_r2(
y, y.mean(), score_type=score_type
)
if not np.allclose(pseudo_r2, 0, atol=10**-7, rtol=0.0):
raise ValueError(
f"pseudo-r2 of {pseudo_r2} for the null model. Should be equal to 0!"
Expand Down Expand Up @@ -241,7 +245,9 @@ def test_pseudo_r2_vs_statsmodels(self, poissonGLM_model_instantiation):
pr2_sms = mdl.pseudo_rsquared("mcf")

# set params
pr2_model = model.observation_model.pseudo_r2(y, mdl.mu, score_type="pseudo-r2-McFadden")
pr2_model = model.observation_model.pseudo_r2(
y, mdl.mu, score_type="pseudo-r2-McFadden"
)

if not np.allclose(pr2_model, pr2_sms):
raise ValueError("Log-likelihood doesn't match statsmodels!")
Expand All @@ -254,27 +260,43 @@ def test_aggregation_score_neg_ll(self, poissonGLM_model_instantiation):

def test_aggregation_score_ll(self, poissonGLM_model_instantiation):
X, y, model, _, firing_rate = poissonGLM_model_instantiation
sm = model.observation_model.log_likelihood(y, firing_rate, aggregate_sample_scores=jnp.sum)
mn = model.observation_model.log_likelihood(y, firing_rate, aggregate_sample_scores=jnp.mean)
sm = model.observation_model.log_likelihood(
y, firing_rate, aggregate_sample_scores=jnp.sum
)
mn = model.observation_model.log_likelihood(
y, firing_rate, aggregate_sample_scores=jnp.mean
)
assert np.allclose(sm, mn * y.shape[0])

@pytest.mark.parametrize("score_type", ["pseudo-r2-McFadden", "pseudo-r2-Cohen"])
def test_aggregation_score_pr2(self, score_type, poissonGLM_model_instantiation):
X, y, model, _, firing_rate = poissonGLM_model_instantiation
sm = model.observation_model.pseudo_r2(y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.sum)
mn = model.observation_model.pseudo_r2(y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.mean)
sm = model.observation_model.pseudo_r2(
y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.sum
)
mn = model.observation_model.pseudo_r2(
y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.mean
)
assert np.allclose(sm, mn)

def test_aggregation_score_mcfadden(self, poissonGLM_model_instantiation):
X, y, model, _, firing_rate = poissonGLM_model_instantiation
sm = model.observation_model._pseudo_r2_mcfadden(y, firing_rate, aggregate_sample_scores=jnp.sum)
mn = model.observation_model._pseudo_r2_mcfadden(y, firing_rate, aggregate_sample_scores=jnp.mean)
sm = model.observation_model._pseudo_r2_mcfadden(
y, firing_rate, aggregate_sample_scores=jnp.sum
)
mn = model.observation_model._pseudo_r2_mcfadden(
y, firing_rate, aggregate_sample_scores=jnp.mean
)
assert np.allclose(sm, mn)

def test_aggregation_score_choen(self, poissonGLM_model_instantiation):
X, y, model, _, firing_rate = poissonGLM_model_instantiation
sm = model.observation_model._pseudo_r2_cohen(y, firing_rate, aggregate_sample_scores=jnp.sum)
mn = model.observation_model._pseudo_r2_cohen(y, firing_rate, aggregate_sample_scores=jnp.mean)
sm = model.observation_model._pseudo_r2_cohen(
y, firing_rate, aggregate_sample_scores=jnp.sum
)
mn = model.observation_model._pseudo_r2_cohen(
y, firing_rate, aggregate_sample_scores=jnp.mean
)
assert np.allclose(sm, mn)


Expand Down Expand Up @@ -421,7 +443,9 @@ def test_pseudo_r2_mean(self, score_type, gammaGLM_model_instantiation):
Check that the pseudo-r2 of the null model is 0.
"""
_, y, model, _, _ = gammaGLM_model_instantiation
pseudo_r2 = model.observation_model.pseudo_r2(y, y.mean(), score_type=score_type)
pseudo_r2 = model.observation_model.pseudo_r2(
y, y.mean(), score_type=score_type
)
if not np.allclose(pseudo_r2, 0, atol=10**-7, rtol=0.0):
raise ValueError(
f"pseudo-r2 of {pseudo_r2} for the null model. Should be equal to 0!"
Expand Down Expand Up @@ -503,12 +527,16 @@ def test_pseudo_r2_vs_statsmodels(self, gammaGLM_model_instantiation):

# statsmodels mcfadden
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="The InversePower link function does")
warnings.filterwarnings(
"ignore", message="The InversePower link function does"
)
mdl = sm.GLM(y, sm.add_constant(X), family=sm.families.Gamma()).fit()
pr2_sms = mdl.pseudo_rsquared("mcf")

# set params
pr2_model = model.observation_model.pseudo_r2(y, mdl.mu, score_type="pseudo-r2-McFadden", scale=mdl.scale)
pr2_model = model.observation_model.pseudo_r2(
y, mdl.mu, score_type="pseudo-r2-McFadden", scale=mdl.scale
)

if not np.allclose(pr2_model, pr2_sms):
raise ValueError("Log-likelihood doesn't match statsmodels!")
Expand All @@ -521,25 +549,41 @@ def test_aggregation_score_neg_ll(self, gammaGLM_model_instantiation):

def test_aggregation_score_ll(self, gammaGLM_model_instantiation):
X, y, model, _, firing_rate = gammaGLM_model_instantiation
sm = model.observation_model.log_likelihood(y, firing_rate, aggregate_sample_scores=jnp.sum)
mn = model.observation_model.log_likelihood(y, firing_rate, aggregate_sample_scores=jnp.mean)
sm = model.observation_model.log_likelihood(
y, firing_rate, aggregate_sample_scores=jnp.sum
)
mn = model.observation_model.log_likelihood(
y, firing_rate, aggregate_sample_scores=jnp.mean
)
assert np.allclose(sm, mn * y.shape[0])

@pytest.mark.parametrize("score_type", ["pseudo-r2-McFadden", "pseudo-r2-Cohen"])
def test_aggregation_score_pr2(self, score_type, gammaGLM_model_instantiation):
X, y, model, _, firing_rate = gammaGLM_model_instantiation
sm = model.observation_model.pseudo_r2(y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.sum)
mn = model.observation_model.pseudo_r2(y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.mean)
sm = model.observation_model.pseudo_r2(
y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.sum
)
mn = model.observation_model.pseudo_r2(
y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.mean
)
assert np.allclose(sm, mn)

def test_aggregation_score_mcfadden(self, gammaGLM_model_instantiation):
X, y, model, _, firing_rate = gammaGLM_model_instantiation
sm = model.observation_model._pseudo_r2_mcfadden(y, firing_rate, aggregate_sample_scores=jnp.sum)
mn = model.observation_model._pseudo_r2_mcfadden(y, firing_rate, aggregate_sample_scores=jnp.mean)
sm = model.observation_model._pseudo_r2_mcfadden(
y, firing_rate, aggregate_sample_scores=jnp.sum
)
mn = model.observation_model._pseudo_r2_mcfadden(
y, firing_rate, aggregate_sample_scores=jnp.mean
)
assert np.allclose(sm, mn)

def test_aggregation_score_choen(self, gammaGLM_model_instantiation):
X, y, model, _, firing_rate = gammaGLM_model_instantiation
sm = model.observation_model._pseudo_r2_cohen(y, firing_rate, aggregate_sample_scores=jnp.sum)
mn = model.observation_model._pseudo_r2_cohen(y, firing_rate, aggregate_sample_scores=jnp.mean)
sm = model.observation_model._pseudo_r2_cohen(
y, firing_rate, aggregate_sample_scores=jnp.sum
)
mn = model.observation_model._pseudo_r2_cohen(
y, firing_rate, aggregate_sample_scores=jnp.mean
)
assert np.allclose(sm, mn)
13 changes: 8 additions & 5 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_sklearn_transformer_pipeline_cv(bas, poissonGLM_model_instantiation):
bas = basis.TransformerBasis(bas)
pipe = pipeline.Pipeline([("basis", bas), ("fit", model)])
param_grid = dict(basis__n_basis_funcs=(4, 5, 10))
gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score='raise')
gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score="raise")
gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y)


Expand All @@ -62,7 +62,9 @@ def test_sklearn_transformer_pipeline_cv_multiprocess(
bas = basis.TransformerBasis(bas)
pipe = pipeline.Pipeline([("basis", bas), ("fit", model)])
param_grid = dict(basis__n_basis_funcs=(4, 5, 10))
gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, n_jobs=3, error_score='raise')
gridsearch = GridSearchCV(
pipe, param_grid=param_grid, cv=3, n_jobs=3, error_score="raise"
)
# use threading instead of fork (this avoids conflicts with jax)
with joblib.parallel_backend("threading"):
gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y)
Expand All @@ -85,7 +87,7 @@ def test_sklearn_transformer_pipeline_cv_directly_over_basis(
bas = basis.TransformerBasis(bas_cls(5))
pipe = pipeline.Pipeline([("transformerbasis", bas), ("fit", model)])
param_grid = dict(transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20)))
gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score='raise')
gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score="raise")
gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y)


Expand All @@ -109,9 +111,10 @@ def test_sklearn_transformer_pipeline_cv_illegal_combination(
transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20)),
transformerbasis__n_basis_funcs=(4, 5, 10),
)
gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score='raise')
gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score="raise")
with pytest.raises(
ValueError, match="Set either new _basis object or parameters for existing _basis, not both."
ValueError,
match="Set either new _basis object or parameters for existing _basis, not both.",
):
gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y)

Expand Down
24 changes: 18 additions & 6 deletions tests/test_proximal_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ def test_prox_operator_returns_tuple(prox_operator, example_data_prox_operator):


@pytest.mark.parametrize("prox_operator", [prox_group_lasso, prox_lasso])
def test_prox_operator_returns_tuple_multineuron(prox_operator, example_data_prox_operator_multineuron):
def test_prox_operator_returns_tuple_multineuron(
prox_operator, example_data_prox_operator_multineuron
):
"""Test whether the tuple returned by the proximal operator has a length of 2."""
args = example_data_prox_operator_multineuron
args = args if prox_operator is prox_group_lasso else (*args[:2], *args[3:])
Expand All @@ -32,7 +34,9 @@ def test_prox_operator_tuple_length(prox_operator, example_data_prox_operator):


@pytest.mark.parametrize("prox_operator", [prox_group_lasso, prox_lasso])
def test_prox_operator_tuple_length_multineuron(prox_operator, example_data_prox_operator_multineuron):
def test_prox_operator_tuple_length_multineuron(
prox_operator, example_data_prox_operator_multineuron
):
"""Test whether the tuple returned by the proximal operator has a length of 2."""
args = example_data_prox_operator_multineuron
args = args if prox_operator is prox_group_lasso else (*args[:2], *args[3:])
Expand All @@ -50,7 +54,9 @@ def test_prox_operator_weights_shape(prox_operator, example_data_prox_operator):


@pytest.mark.parametrize("prox_operator", [prox_group_lasso, prox_lasso])
def test_prox_operator_weights_shape_multineuron(prox_operator, example_data_prox_operator_multineuron):
def test_prox_operator_weights_shape_multineuron(
prox_operator, example_data_prox_operator_multineuron
):
"""Test whether the shape of the weights in the proximal operator is correct."""
args = example_data_prox_operator_multineuron
args = args if prox_operator is prox_group_lasso else (*args[:2], *args[3:])
Expand All @@ -68,7 +74,9 @@ def test_prox_operator_intercepts_shape(prox_operator, example_data_prox_operato


@pytest.mark.parametrize("prox_operator", [prox_group_lasso, prox_lasso])
def test_prox_operator_intercepts_shape_multineuron(prox_operator, example_data_prox_operator_multineuron):
def test_prox_operator_intercepts_shape_multineuron(
prox_operator, example_data_prox_operator_multineuron
):
"""Test whether the shape of the intercepts in the proximal operator is correct."""
args = example_data_prox_operator_multineuron
args = args if prox_operator is prox_group_lasso else (*args[:2], *args[3:])
Expand Down Expand Up @@ -104,7 +112,9 @@ def test_vmap_norm2_masked_2_non_negative(example_data_prox_operator):
assert jnp.all(l2_norm >= 0)


def test_vmap_norm2_masked_2_non_negative_multineuron(example_data_prox_operator_multineuron):
def test_vmap_norm2_masked_2_non_negative_multineuron(
example_data_prox_operator_multineuron,
):
"""Test whether all elements of the result from _vmap_norm2_masked_2 are non-negative."""
params, _, mask, _ = example_data_prox_operator_multineuron
l2_norm = _vmap_norm2_masked_2(params[0].T, mask)
Expand All @@ -119,7 +129,9 @@ def test_prox_operator_shrinks_only_masked(example_data_prox_operator):
assert all(params_new[0][i] < params[0][i] for i in [0, 2, 3])


def test_prox_operator_shrinks_only_masked_multineuron(example_data_prox_operator_multineuron):
def test_prox_operator_shrinks_only_masked_multineuron(
example_data_prox_operator_multineuron,
):
params, _, mask, _ = example_data_prox_operator_multineuron
mask = mask.astype(float)
mask = mask.at[:, 1].set(jnp.zeros(2))
Expand Down
Loading
Loading