Skip to content

Commit

Permalink
Merge pull request #258 from flatironinstitute/lint_tests
Browse files Browse the repository at this point in the history
added linting tests to tox.ini and linted
  • Loading branch information
BalzaniEdoardo authored Oct 29, 2024
2 parents b85f408 + 7d2a69b commit 6056838
Show file tree
Hide file tree
Showing 14 changed files with 332 additions and 129 deletions.
1 change: 1 addition & 0 deletions tests/test_base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def score(self, X, y, score_type="pseudo-r2-McFadden"):
def get_optimal_solver_params_config(self):
return None, None, None


class BadEstimator(Base):
def __init__(self, param1, *args):
super().__init__()
Expand Down
8 changes: 5 additions & 3 deletions tests/test_convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,14 @@ def test_ridge_convergence(solver_names):
y = np.random.poisson(rate)

# instantiate and fit ridge GLM with GradientDescent
model_GD = nmo.glm.GLM(regularizer_strength=1., regularizer="Ridge", solver_kwargs=dict(tol=10**-12))
model_GD = nmo.glm.GLM(
regularizer_strength=1.0, regularizer="Ridge", solver_kwargs=dict(tol=10**-12)
)
model_GD.fit(X, y)

# instantiate and fit ridge GLM with ProximalGradient
model_PG = nmo.glm.GLM(
regularizer_strength=1.,
regularizer_strength=1.0,
regularizer="Ridge",
solver_name="ProximalGradient",
solver_kwargs=dict(tol=10**-12),
Expand Down Expand Up @@ -109,7 +111,7 @@ def test_lasso_convergence(solver_name):
# instantiate and fit GLM with ProximalGradient
model_PG = nmo.glm.GLM(
regularizer="Lasso",
regularizer_strength=1.,
regularizer_strength=1.0,
solver_name="ProximalGradient",
solver_kwargs=dict(tol=10**-12),
)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,9 @@ def test_tree_structure_match(self, trial_counts, axis):
conv = convolve.create_convolutional_predictor(
basis_matrix, trial_counts, axis=axis
)
assert jax.tree_util.tree_structure(trial_counts) == jax.tree_util.tree_structure(conv)
assert jax.tree_util.tree_structure(
trial_counts
) == jax.tree_util.tree_structure(conv)

@pytest.mark.parametrize("axis", [0, 1, 2])
@pytest.mark.parametrize(
Expand Down
13 changes: 10 additions & 3 deletions tests/test_glm_initialization.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import warnings

import jax
import jax.numpy as jnp
import numpy as np
import pytest

import nemos as nmo
import warnings


@pytest.mark.parametrize(
"non_linearity",
Expand Down Expand Up @@ -56,14 +58,19 @@ def test_initialization_error_nan_input(non_linearity, expectation):
inverse_link_function=non_linearity, y=output_y
)


def test_initialization_error_non_invertible():
"""Initialize invalid."""
output_y = np.random.uniform(size=100)
inv_link = lambda x: jax.nn.softplus(x) + 10
with pytest.raises(ValueError, match="Failed to initialize the model intercept.+Please, provide"):
with pytest.raises(
ValueError, match="Failed to initialize the model intercept.+Please, provide"
):
with warnings.catch_warnings():
# ignore the warning raised by the root-finder (there is no root)
warnings.filterwarnings("ignore", category=RuntimeWarning, message="Tolerance of")
warnings.filterwarnings(
"ignore", category=RuntimeWarning, message="Tolerance of"
)
nmo.initialize_regressor.initialize_intercept_matching_mean_rate(
inverse_link_function=inv_link, y=output_y
)
88 changes: 66 additions & 22 deletions tests/test_observation_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ def test_pseudo_r2_range(self, score_type, poissonGLM_model_instantiation):
Compute the pseudo-r2 and check that is < 1.
"""
_, y, model, _, firing_rate = poissonGLM_model_instantiation
pseudo_r2 = model.observation_model.pseudo_r2(y, firing_rate, score_type=score_type)
pseudo_r2 = model.observation_model.pseudo_r2(
y, firing_rate, score_type=score_type
)
if (pseudo_r2 > 1) or (pseudo_r2 < 0):
raise ValueError(f"pseudo-r2 of {pseudo_r2} outside the [0,1] range!")

Expand All @@ -157,7 +159,9 @@ def test_pseudo_r2_mean(self, score_type, poissonGLM_model_instantiation):
Check that the pseudo-r2 of the null model is 0.
"""
_, y, model, _, _ = poissonGLM_model_instantiation
pseudo_r2 = model.observation_model.pseudo_r2(y, y.mean(), score_type=score_type)
pseudo_r2 = model.observation_model.pseudo_r2(
y, y.mean(), score_type=score_type
)
if not np.allclose(pseudo_r2, 0, atol=10**-7, rtol=0.0):
raise ValueError(
f"pseudo-r2 of {pseudo_r2} for the null model. Should be equal to 0!"
Expand Down Expand Up @@ -241,7 +245,9 @@ def test_pseudo_r2_vs_statsmodels(self, poissonGLM_model_instantiation):
pr2_sms = mdl.pseudo_rsquared("mcf")

# set params
pr2_model = model.observation_model.pseudo_r2(y, mdl.mu, score_type="pseudo-r2-McFadden")
pr2_model = model.observation_model.pseudo_r2(
y, mdl.mu, score_type="pseudo-r2-McFadden"
)

if not np.allclose(pr2_model, pr2_sms):
raise ValueError("Log-likelihood doesn't match statsmodels!")
Expand All @@ -254,27 +260,43 @@ def test_aggregation_score_neg_ll(self, poissonGLM_model_instantiation):

def test_aggregation_score_ll(self, poissonGLM_model_instantiation):
X, y, model, _, firing_rate = poissonGLM_model_instantiation
sm = model.observation_model.log_likelihood(y, firing_rate, aggregate_sample_scores=jnp.sum)
mn = model.observation_model.log_likelihood(y, firing_rate, aggregate_sample_scores=jnp.mean)
sm = model.observation_model.log_likelihood(
y, firing_rate, aggregate_sample_scores=jnp.sum
)
mn = model.observation_model.log_likelihood(
y, firing_rate, aggregate_sample_scores=jnp.mean
)
assert np.allclose(sm, mn * y.shape[0])

@pytest.mark.parametrize("score_type", ["pseudo-r2-McFadden", "pseudo-r2-Cohen"])
def test_aggregation_score_pr2(self, score_type, poissonGLM_model_instantiation):
X, y, model, _, firing_rate = poissonGLM_model_instantiation
sm = model.observation_model.pseudo_r2(y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.sum)
mn = model.observation_model.pseudo_r2(y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.mean)
sm = model.observation_model.pseudo_r2(
y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.sum
)
mn = model.observation_model.pseudo_r2(
y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.mean
)
assert np.allclose(sm, mn)

def test_aggregation_score_mcfadden(self, poissonGLM_model_instantiation):
X, y, model, _, firing_rate = poissonGLM_model_instantiation
sm = model.observation_model._pseudo_r2_mcfadden(y, firing_rate, aggregate_sample_scores=jnp.sum)
mn = model.observation_model._pseudo_r2_mcfadden(y, firing_rate, aggregate_sample_scores=jnp.mean)
sm = model.observation_model._pseudo_r2_mcfadden(
y, firing_rate, aggregate_sample_scores=jnp.sum
)
mn = model.observation_model._pseudo_r2_mcfadden(
y, firing_rate, aggregate_sample_scores=jnp.mean
)
assert np.allclose(sm, mn)

def test_aggregation_score_choen(self, poissonGLM_model_instantiation):
X, y, model, _, firing_rate = poissonGLM_model_instantiation
sm = model.observation_model._pseudo_r2_cohen(y, firing_rate, aggregate_sample_scores=jnp.sum)
mn = model.observation_model._pseudo_r2_cohen(y, firing_rate, aggregate_sample_scores=jnp.mean)
sm = model.observation_model._pseudo_r2_cohen(
y, firing_rate, aggregate_sample_scores=jnp.sum
)
mn = model.observation_model._pseudo_r2_cohen(
y, firing_rate, aggregate_sample_scores=jnp.mean
)
assert np.allclose(sm, mn)


Expand Down Expand Up @@ -421,7 +443,9 @@ def test_pseudo_r2_mean(self, score_type, gammaGLM_model_instantiation):
Check that the pseudo-r2 of the null model is 0.
"""
_, y, model, _, _ = gammaGLM_model_instantiation
pseudo_r2 = model.observation_model.pseudo_r2(y, y.mean(), score_type=score_type)
pseudo_r2 = model.observation_model.pseudo_r2(
y, y.mean(), score_type=score_type
)
if not np.allclose(pseudo_r2, 0, atol=10**-7, rtol=0.0):
raise ValueError(
f"pseudo-r2 of {pseudo_r2} for the null model. Should be equal to 0!"
Expand Down Expand Up @@ -503,12 +527,16 @@ def test_pseudo_r2_vs_statsmodels(self, gammaGLM_model_instantiation):

# statsmodels mcfadden
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="The InversePower link function does")
warnings.filterwarnings(
"ignore", message="The InversePower link function does"
)
mdl = sm.GLM(y, sm.add_constant(X), family=sm.families.Gamma()).fit()
pr2_sms = mdl.pseudo_rsquared("mcf")

# set params
pr2_model = model.observation_model.pseudo_r2(y, mdl.mu, score_type="pseudo-r2-McFadden", scale=mdl.scale)
pr2_model = model.observation_model.pseudo_r2(
y, mdl.mu, score_type="pseudo-r2-McFadden", scale=mdl.scale
)

if not np.allclose(pr2_model, pr2_sms):
raise ValueError("Log-likelihood doesn't match statsmodels!")
Expand All @@ -521,25 +549,41 @@ def test_aggregation_score_neg_ll(self, gammaGLM_model_instantiation):

def test_aggregation_score_ll(self, gammaGLM_model_instantiation):
X, y, model, _, firing_rate = gammaGLM_model_instantiation
sm = model.observation_model.log_likelihood(y, firing_rate, aggregate_sample_scores=jnp.sum)
mn = model.observation_model.log_likelihood(y, firing_rate, aggregate_sample_scores=jnp.mean)
sm = model.observation_model.log_likelihood(
y, firing_rate, aggregate_sample_scores=jnp.sum
)
mn = model.observation_model.log_likelihood(
y, firing_rate, aggregate_sample_scores=jnp.mean
)
assert np.allclose(sm, mn * y.shape[0])

@pytest.mark.parametrize("score_type", ["pseudo-r2-McFadden", "pseudo-r2-Cohen"])
def test_aggregation_score_pr2(self, score_type, gammaGLM_model_instantiation):
X, y, model, _, firing_rate = gammaGLM_model_instantiation
sm = model.observation_model.pseudo_r2(y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.sum)
mn = model.observation_model.pseudo_r2(y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.mean)
sm = model.observation_model.pseudo_r2(
y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.sum
)
mn = model.observation_model.pseudo_r2(
y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.mean
)
assert np.allclose(sm, mn)

def test_aggregation_score_mcfadden(self, gammaGLM_model_instantiation):
X, y, model, _, firing_rate = gammaGLM_model_instantiation
sm = model.observation_model._pseudo_r2_mcfadden(y, firing_rate, aggregate_sample_scores=jnp.sum)
mn = model.observation_model._pseudo_r2_mcfadden(y, firing_rate, aggregate_sample_scores=jnp.mean)
sm = model.observation_model._pseudo_r2_mcfadden(
y, firing_rate, aggregate_sample_scores=jnp.sum
)
mn = model.observation_model._pseudo_r2_mcfadden(
y, firing_rate, aggregate_sample_scores=jnp.mean
)
assert np.allclose(sm, mn)

def test_aggregation_score_choen(self, gammaGLM_model_instantiation):
X, y, model, _, firing_rate = gammaGLM_model_instantiation
sm = model.observation_model._pseudo_r2_cohen(y, firing_rate, aggregate_sample_scores=jnp.sum)
mn = model.observation_model._pseudo_r2_cohen(y, firing_rate, aggregate_sample_scores=jnp.mean)
sm = model.observation_model._pseudo_r2_cohen(
y, firing_rate, aggregate_sample_scores=jnp.sum
)
mn = model.observation_model._pseudo_r2_cohen(
y, firing_rate, aggregate_sample_scores=jnp.mean
)
assert np.allclose(sm, mn)
13 changes: 8 additions & 5 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_sklearn_transformer_pipeline_cv(bas, poissonGLM_model_instantiation):
bas = basis.TransformerBasis(bas)
pipe = pipeline.Pipeline([("basis", bas), ("fit", model)])
param_grid = dict(basis__n_basis_funcs=(4, 5, 10))
gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score='raise')
gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score="raise")
gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y)


Expand All @@ -62,7 +62,9 @@ def test_sklearn_transformer_pipeline_cv_multiprocess(
bas = basis.TransformerBasis(bas)
pipe = pipeline.Pipeline([("basis", bas), ("fit", model)])
param_grid = dict(basis__n_basis_funcs=(4, 5, 10))
gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, n_jobs=3, error_score='raise')
gridsearch = GridSearchCV(
pipe, param_grid=param_grid, cv=3, n_jobs=3, error_score="raise"
)
# use threading instead of fork (this avoids conflicts with jax)
with joblib.parallel_backend("threading"):
gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y)
Expand All @@ -85,7 +87,7 @@ def test_sklearn_transformer_pipeline_cv_directly_over_basis(
bas = basis.TransformerBasis(bas_cls(5))
pipe = pipeline.Pipeline([("transformerbasis", bas), ("fit", model)])
param_grid = dict(transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20)))
gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score='raise')
gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score="raise")
gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y)


Expand All @@ -109,9 +111,10 @@ def test_sklearn_transformer_pipeline_cv_illegal_combination(
transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20)),
transformerbasis__n_basis_funcs=(4, 5, 10),
)
gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score='raise')
gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score="raise")
with pytest.raises(
ValueError, match="Set either new _basis object or parameters for existing _basis, not both."
ValueError,
match="Set either new _basis object or parameters for existing _basis, not both.",
):
gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y)

Expand Down
24 changes: 18 additions & 6 deletions tests/test_proximal_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ def test_prox_operator_returns_tuple(prox_operator, example_data_prox_operator):


@pytest.mark.parametrize("prox_operator", [prox_group_lasso, prox_lasso])
def test_prox_operator_returns_tuple_multineuron(prox_operator, example_data_prox_operator_multineuron):
def test_prox_operator_returns_tuple_multineuron(
prox_operator, example_data_prox_operator_multineuron
):
"""Test whether the tuple returned by the proximal operator has a length of 2."""
args = example_data_prox_operator_multineuron
args = args if prox_operator is prox_group_lasso else (*args[:2], *args[3:])
Expand All @@ -32,7 +34,9 @@ def test_prox_operator_tuple_length(prox_operator, example_data_prox_operator):


@pytest.mark.parametrize("prox_operator", [prox_group_lasso, prox_lasso])
def test_prox_operator_tuple_length_multineuron(prox_operator, example_data_prox_operator_multineuron):
def test_prox_operator_tuple_length_multineuron(
prox_operator, example_data_prox_operator_multineuron
):
"""Test whether the tuple returned by the proximal operator has a length of 2."""
args = example_data_prox_operator_multineuron
args = args if prox_operator is prox_group_lasso else (*args[:2], *args[3:])
Expand All @@ -50,7 +54,9 @@ def test_prox_operator_weights_shape(prox_operator, example_data_prox_operator):


@pytest.mark.parametrize("prox_operator", [prox_group_lasso, prox_lasso])
def test_prox_operator_weights_shape_multineuron(prox_operator, example_data_prox_operator_multineuron):
def test_prox_operator_weights_shape_multineuron(
prox_operator, example_data_prox_operator_multineuron
):
"""Test whether the shape of the weights in the proximal operator is correct."""
args = example_data_prox_operator_multineuron
args = args if prox_operator is prox_group_lasso else (*args[:2], *args[3:])
Expand All @@ -68,7 +74,9 @@ def test_prox_operator_intercepts_shape(prox_operator, example_data_prox_operato


@pytest.mark.parametrize("prox_operator", [prox_group_lasso, prox_lasso])
def test_prox_operator_intercepts_shape_multineuron(prox_operator, example_data_prox_operator_multineuron):
def test_prox_operator_intercepts_shape_multineuron(
prox_operator, example_data_prox_operator_multineuron
):
"""Test whether the shape of the intercepts in the proximal operator is correct."""
args = example_data_prox_operator_multineuron
args = args if prox_operator is prox_group_lasso else (*args[:2], *args[3:])
Expand Down Expand Up @@ -104,7 +112,9 @@ def test_vmap_norm2_masked_2_non_negative(example_data_prox_operator):
assert jnp.all(l2_norm >= 0)


def test_vmap_norm2_masked_2_non_negative_multineuron(example_data_prox_operator_multineuron):
def test_vmap_norm2_masked_2_non_negative_multineuron(
example_data_prox_operator_multineuron,
):
"""Test whether all elements of the result from _vmap_norm2_masked_2 are non-negative."""
params, _, mask, _ = example_data_prox_operator_multineuron
l2_norm = _vmap_norm2_masked_2(params[0].T, mask)
Expand All @@ -119,7 +129,9 @@ def test_prox_operator_shrinks_only_masked(example_data_prox_operator):
assert all(params_new[0][i] < params[0][i] for i in [0, 2, 3])


def test_prox_operator_shrinks_only_masked_multineuron(example_data_prox_operator_multineuron):
def test_prox_operator_shrinks_only_masked_multineuron(
example_data_prox_operator_multineuron,
):
params, _, mask, _ = example_data_prox_operator_multineuron
mask = mask.astype(float)
mask = mask.at[:, 1].set(jnp.zeros(2))
Expand Down
Loading

0 comments on commit 6056838

Please sign in to comment.