diff --git a/tests/test_base_class.py b/tests/test_base_class.py index 1f89f7d6..2324434b 100644 --- a/tests/test_base_class.py +++ b/tests/test_base_class.py @@ -27,6 +27,7 @@ def score(self, X, y, score_type="pseudo-r2-McFadden"): def get_optimal_solver_params_config(self): return None, None, None + class BadEstimator(Base): def __init__(self, param1, *args): super().__init__() diff --git a/tests/test_convergence.py b/tests/test_convergence.py index 5a8814e1..f7629dcf 100644 --- a/tests/test_convergence.py +++ b/tests/test_convergence.py @@ -76,12 +76,14 @@ def test_ridge_convergence(solver_names): y = np.random.poisson(rate) # instantiate and fit ridge GLM with GradientDescent - model_GD = nmo.glm.GLM(regularizer_strength=1., regularizer="Ridge", solver_kwargs=dict(tol=10**-12)) + model_GD = nmo.glm.GLM( + regularizer_strength=1.0, regularizer="Ridge", solver_kwargs=dict(tol=10**-12) + ) model_GD.fit(X, y) # instantiate and fit ridge GLM with ProximalGradient model_PG = nmo.glm.GLM( - regularizer_strength=1., + regularizer_strength=1.0, regularizer="Ridge", solver_name="ProximalGradient", solver_kwargs=dict(tol=10**-12), @@ -109,7 +111,7 @@ def test_lasso_convergence(solver_name): # instantiate and fit GLM with ProximalGradient model_PG = nmo.glm.GLM( regularizer="Lasso", - regularizer_strength=1., + regularizer_strength=1.0, solver_name="ProximalGradient", solver_kwargs=dict(tol=10**-12), ) diff --git a/tests/test_convolution.py b/tests/test_convolution.py index 4c5b75ec..8fb5aecf 100644 --- a/tests/test_convolution.py +++ b/tests/test_convolution.py @@ -329,7 +329,9 @@ def test_tree_structure_match(self, trial_counts, axis): conv = convolve.create_convolutional_predictor( basis_matrix, trial_counts, axis=axis ) - assert jax.tree_util.tree_structure(trial_counts) == jax.tree_util.tree_structure(conv) + assert jax.tree_util.tree_structure( + trial_counts + ) == jax.tree_util.tree_structure(conv) @pytest.mark.parametrize("axis", [0, 1, 2]) @pytest.mark.parametrize( diff --git a/tests/test_glm_initialization.py b/tests/test_glm_initialization.py index 972b98ae..fdef35d5 100644 --- a/tests/test_glm_initialization.py +++ b/tests/test_glm_initialization.py @@ -1,10 +1,12 @@ +import warnings + import jax import jax.numpy as jnp import numpy as np import pytest import nemos as nmo -import warnings + @pytest.mark.parametrize( "non_linearity", @@ -56,14 +58,19 @@ def test_initialization_error_nan_input(non_linearity, expectation): inverse_link_function=non_linearity, y=output_y ) + def test_initialization_error_non_invertible(): """Initialize invalid.""" output_y = np.random.uniform(size=100) inv_link = lambda x: jax.nn.softplus(x) + 10 - with pytest.raises(ValueError, match="Failed to initialize the model intercept.+Please, provide"): + with pytest.raises( + ValueError, match="Failed to initialize the model intercept.+Please, provide" + ): with warnings.catch_warnings(): # ignore the warning raised by the root-finder (there is no root) - warnings.filterwarnings("ignore", category=RuntimeWarning, message="Tolerance of") + warnings.filterwarnings( + "ignore", category=RuntimeWarning, message="Tolerance of" + ) nmo.initialize_regressor.initialize_intercept_matching_mean_rate( inverse_link_function=inv_link, y=output_y ) diff --git a/tests/test_observation_models.py b/tests/test_observation_models.py index 25262035..087c8091 100644 --- a/tests/test_observation_models.py +++ b/tests/test_observation_models.py @@ -147,7 +147,9 @@ def test_pseudo_r2_range(self, score_type, poissonGLM_model_instantiation): Compute the pseudo-r2 and check that is < 1. """ _, y, model, _, firing_rate = poissonGLM_model_instantiation - pseudo_r2 = model.observation_model.pseudo_r2(y, firing_rate, score_type=score_type) + pseudo_r2 = model.observation_model.pseudo_r2( + y, firing_rate, score_type=score_type + ) if (pseudo_r2 > 1) or (pseudo_r2 < 0): raise ValueError(f"pseudo-r2 of {pseudo_r2} outside the [0,1] range!") @@ -157,7 +159,9 @@ def test_pseudo_r2_mean(self, score_type, poissonGLM_model_instantiation): Check that the pseudo-r2 of the null model is 0. """ _, y, model, _, _ = poissonGLM_model_instantiation - pseudo_r2 = model.observation_model.pseudo_r2(y, y.mean(), score_type=score_type) + pseudo_r2 = model.observation_model.pseudo_r2( + y, y.mean(), score_type=score_type + ) if not np.allclose(pseudo_r2, 0, atol=10**-7, rtol=0.0): raise ValueError( f"pseudo-r2 of {pseudo_r2} for the null model. Should be equal to 0!" @@ -241,7 +245,9 @@ def test_pseudo_r2_vs_statsmodels(self, poissonGLM_model_instantiation): pr2_sms = mdl.pseudo_rsquared("mcf") # set params - pr2_model = model.observation_model.pseudo_r2(y, mdl.mu, score_type="pseudo-r2-McFadden") + pr2_model = model.observation_model.pseudo_r2( + y, mdl.mu, score_type="pseudo-r2-McFadden" + ) if not np.allclose(pr2_model, pr2_sms): raise ValueError("Log-likelihood doesn't match statsmodels!") @@ -254,27 +260,43 @@ def test_aggregation_score_neg_ll(self, poissonGLM_model_instantiation): def test_aggregation_score_ll(self, poissonGLM_model_instantiation): X, y, model, _, firing_rate = poissonGLM_model_instantiation - sm = model.observation_model.log_likelihood(y, firing_rate, aggregate_sample_scores=jnp.sum) - mn = model.observation_model.log_likelihood(y, firing_rate, aggregate_sample_scores=jnp.mean) + sm = model.observation_model.log_likelihood( + y, firing_rate, aggregate_sample_scores=jnp.sum + ) + mn = model.observation_model.log_likelihood( + y, firing_rate, aggregate_sample_scores=jnp.mean + ) assert np.allclose(sm, mn * y.shape[0]) @pytest.mark.parametrize("score_type", ["pseudo-r2-McFadden", "pseudo-r2-Cohen"]) def test_aggregation_score_pr2(self, score_type, poissonGLM_model_instantiation): X, y, model, _, firing_rate = poissonGLM_model_instantiation - sm = model.observation_model.pseudo_r2(y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.sum) - mn = model.observation_model.pseudo_r2(y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.mean) + sm = model.observation_model.pseudo_r2( + y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.sum + ) + mn = model.observation_model.pseudo_r2( + y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.mean + ) assert np.allclose(sm, mn) def test_aggregation_score_mcfadden(self, poissonGLM_model_instantiation): X, y, model, _, firing_rate = poissonGLM_model_instantiation - sm = model.observation_model._pseudo_r2_mcfadden(y, firing_rate, aggregate_sample_scores=jnp.sum) - mn = model.observation_model._pseudo_r2_mcfadden(y, firing_rate, aggregate_sample_scores=jnp.mean) + sm = model.observation_model._pseudo_r2_mcfadden( + y, firing_rate, aggregate_sample_scores=jnp.sum + ) + mn = model.observation_model._pseudo_r2_mcfadden( + y, firing_rate, aggregate_sample_scores=jnp.mean + ) assert np.allclose(sm, mn) def test_aggregation_score_choen(self, poissonGLM_model_instantiation): X, y, model, _, firing_rate = poissonGLM_model_instantiation - sm = model.observation_model._pseudo_r2_cohen(y, firing_rate, aggregate_sample_scores=jnp.sum) - mn = model.observation_model._pseudo_r2_cohen(y, firing_rate, aggregate_sample_scores=jnp.mean) + sm = model.observation_model._pseudo_r2_cohen( + y, firing_rate, aggregate_sample_scores=jnp.sum + ) + mn = model.observation_model._pseudo_r2_cohen( + y, firing_rate, aggregate_sample_scores=jnp.mean + ) assert np.allclose(sm, mn) @@ -421,7 +443,9 @@ def test_pseudo_r2_mean(self, score_type, gammaGLM_model_instantiation): Check that the pseudo-r2 of the null model is 0. """ _, y, model, _, _ = gammaGLM_model_instantiation - pseudo_r2 = model.observation_model.pseudo_r2(y, y.mean(), score_type=score_type) + pseudo_r2 = model.observation_model.pseudo_r2( + y, y.mean(), score_type=score_type + ) if not np.allclose(pseudo_r2, 0, atol=10**-7, rtol=0.0): raise ValueError( f"pseudo-r2 of {pseudo_r2} for the null model. Should be equal to 0!" @@ -503,12 +527,16 @@ def test_pseudo_r2_vs_statsmodels(self, gammaGLM_model_instantiation): # statsmodels mcfadden with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="The InversePower link function does") + warnings.filterwarnings( + "ignore", message="The InversePower link function does" + ) mdl = sm.GLM(y, sm.add_constant(X), family=sm.families.Gamma()).fit() pr2_sms = mdl.pseudo_rsquared("mcf") # set params - pr2_model = model.observation_model.pseudo_r2(y, mdl.mu, score_type="pseudo-r2-McFadden", scale=mdl.scale) + pr2_model = model.observation_model.pseudo_r2( + y, mdl.mu, score_type="pseudo-r2-McFadden", scale=mdl.scale + ) if not np.allclose(pr2_model, pr2_sms): raise ValueError("Log-likelihood doesn't match statsmodels!") @@ -521,25 +549,41 @@ def test_aggregation_score_neg_ll(self, gammaGLM_model_instantiation): def test_aggregation_score_ll(self, gammaGLM_model_instantiation): X, y, model, _, firing_rate = gammaGLM_model_instantiation - sm = model.observation_model.log_likelihood(y, firing_rate, aggregate_sample_scores=jnp.sum) - mn = model.observation_model.log_likelihood(y, firing_rate, aggregate_sample_scores=jnp.mean) + sm = model.observation_model.log_likelihood( + y, firing_rate, aggregate_sample_scores=jnp.sum + ) + mn = model.observation_model.log_likelihood( + y, firing_rate, aggregate_sample_scores=jnp.mean + ) assert np.allclose(sm, mn * y.shape[0]) @pytest.mark.parametrize("score_type", ["pseudo-r2-McFadden", "pseudo-r2-Cohen"]) def test_aggregation_score_pr2(self, score_type, gammaGLM_model_instantiation): X, y, model, _, firing_rate = gammaGLM_model_instantiation - sm = model.observation_model.pseudo_r2(y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.sum) - mn = model.observation_model.pseudo_r2(y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.mean) + sm = model.observation_model.pseudo_r2( + y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.sum + ) + mn = model.observation_model.pseudo_r2( + y, firing_rate, score_type=score_type, aggregate_sample_scores=jnp.mean + ) assert np.allclose(sm, mn) def test_aggregation_score_mcfadden(self, gammaGLM_model_instantiation): X, y, model, _, firing_rate = gammaGLM_model_instantiation - sm = model.observation_model._pseudo_r2_mcfadden(y, firing_rate, aggregate_sample_scores=jnp.sum) - mn = model.observation_model._pseudo_r2_mcfadden(y, firing_rate, aggregate_sample_scores=jnp.mean) + sm = model.observation_model._pseudo_r2_mcfadden( + y, firing_rate, aggregate_sample_scores=jnp.sum + ) + mn = model.observation_model._pseudo_r2_mcfadden( + y, firing_rate, aggregate_sample_scores=jnp.mean + ) assert np.allclose(sm, mn) def test_aggregation_score_choen(self, gammaGLM_model_instantiation): X, y, model, _, firing_rate = gammaGLM_model_instantiation - sm = model.observation_model._pseudo_r2_cohen(y, firing_rate, aggregate_sample_scores=jnp.sum) - mn = model.observation_model._pseudo_r2_cohen(y, firing_rate, aggregate_sample_scores=jnp.mean) + sm = model.observation_model._pseudo_r2_cohen( + y, firing_rate, aggregate_sample_scores=jnp.sum + ) + mn = model.observation_model._pseudo_r2_cohen( + y, firing_rate, aggregate_sample_scores=jnp.mean + ) assert np.allclose(sm, mn) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 26b58d14..703d42ff 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -41,7 +41,7 @@ def test_sklearn_transformer_pipeline_cv(bas, poissonGLM_model_instantiation): bas = basis.TransformerBasis(bas) pipe = pipeline.Pipeline([("basis", bas), ("fit", model)]) param_grid = dict(basis__n_basis_funcs=(4, 5, 10)) - gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score='raise') + gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score="raise") gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y) @@ -62,7 +62,9 @@ def test_sklearn_transformer_pipeline_cv_multiprocess( bas = basis.TransformerBasis(bas) pipe = pipeline.Pipeline([("basis", bas), ("fit", model)]) param_grid = dict(basis__n_basis_funcs=(4, 5, 10)) - gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, n_jobs=3, error_score='raise') + gridsearch = GridSearchCV( + pipe, param_grid=param_grid, cv=3, n_jobs=3, error_score="raise" + ) # use threading instead of fork (this avoids conflicts with jax) with joblib.parallel_backend("threading"): gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y) @@ -85,7 +87,7 @@ def test_sklearn_transformer_pipeline_cv_directly_over_basis( bas = basis.TransformerBasis(bas_cls(5)) pipe = pipeline.Pipeline([("transformerbasis", bas), ("fit", model)]) param_grid = dict(transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20))) - gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score='raise') + gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score="raise") gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y) @@ -109,9 +111,10 @@ def test_sklearn_transformer_pipeline_cv_illegal_combination( transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20)), transformerbasis__n_basis_funcs=(4, 5, 10), ) - gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score='raise') + gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score="raise") with pytest.raises( - ValueError, match="Set either new _basis object or parameters for existing _basis, not both." + ValueError, + match="Set either new _basis object or parameters for existing _basis, not both.", ): gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y) diff --git a/tests/test_proximal_operator.py b/tests/test_proximal_operator.py index a6a65bfb..bda0d3e6 100644 --- a/tests/test_proximal_operator.py +++ b/tests/test_proximal_operator.py @@ -14,7 +14,9 @@ def test_prox_operator_returns_tuple(prox_operator, example_data_prox_operator): @pytest.mark.parametrize("prox_operator", [prox_group_lasso, prox_lasso]) -def test_prox_operator_returns_tuple_multineuron(prox_operator, example_data_prox_operator_multineuron): +def test_prox_operator_returns_tuple_multineuron( + prox_operator, example_data_prox_operator_multineuron +): """Test whether the tuple returned by the proximal operator has a length of 2.""" args = example_data_prox_operator_multineuron args = args if prox_operator is prox_group_lasso else (*args[:2], *args[3:]) @@ -32,7 +34,9 @@ def test_prox_operator_tuple_length(prox_operator, example_data_prox_operator): @pytest.mark.parametrize("prox_operator", [prox_group_lasso, prox_lasso]) -def test_prox_operator_tuple_length_multineuron(prox_operator, example_data_prox_operator_multineuron): +def test_prox_operator_tuple_length_multineuron( + prox_operator, example_data_prox_operator_multineuron +): """Test whether the tuple returned by the proximal operator has a length of 2.""" args = example_data_prox_operator_multineuron args = args if prox_operator is prox_group_lasso else (*args[:2], *args[3:]) @@ -50,7 +54,9 @@ def test_prox_operator_weights_shape(prox_operator, example_data_prox_operator): @pytest.mark.parametrize("prox_operator", [prox_group_lasso, prox_lasso]) -def test_prox_operator_weights_shape_multineuron(prox_operator, example_data_prox_operator_multineuron): +def test_prox_operator_weights_shape_multineuron( + prox_operator, example_data_prox_operator_multineuron +): """Test whether the shape of the weights in the proximal operator is correct.""" args = example_data_prox_operator_multineuron args = args if prox_operator is prox_group_lasso else (*args[:2], *args[3:]) @@ -68,7 +74,9 @@ def test_prox_operator_intercepts_shape(prox_operator, example_data_prox_operato @pytest.mark.parametrize("prox_operator", [prox_group_lasso, prox_lasso]) -def test_prox_operator_intercepts_shape_multineuron(prox_operator, example_data_prox_operator_multineuron): +def test_prox_operator_intercepts_shape_multineuron( + prox_operator, example_data_prox_operator_multineuron +): """Test whether the shape of the intercepts in the proximal operator is correct.""" args = example_data_prox_operator_multineuron args = args if prox_operator is prox_group_lasso else (*args[:2], *args[3:]) @@ -104,7 +112,9 @@ def test_vmap_norm2_masked_2_non_negative(example_data_prox_operator): assert jnp.all(l2_norm >= 0) -def test_vmap_norm2_masked_2_non_negative_multineuron(example_data_prox_operator_multineuron): +def test_vmap_norm2_masked_2_non_negative_multineuron( + example_data_prox_operator_multineuron, +): """Test whether all elements of the result from _vmap_norm2_masked_2 are non-negative.""" params, _, mask, _ = example_data_prox_operator_multineuron l2_norm = _vmap_norm2_masked_2(params[0].T, mask) @@ -119,7 +129,9 @@ def test_prox_operator_shrinks_only_masked(example_data_prox_operator): assert all(params_new[0][i] < params[0][i] for i in [0, 2, 3]) -def test_prox_operator_shrinks_only_masked_multineuron(example_data_prox_operator_multineuron): +def test_prox_operator_shrinks_only_masked_multineuron( + example_data_prox_operator_multineuron, +): params, _, mask, _ = example_data_prox_operator_multineuron mask = mask.astype(float) mask = mask.at[:, 1].set(jnp.zeros(2)) diff --git a/tests/test_regularizer.py b/tests/test_regularizer.py index 5abba876..96e1b41f 100644 --- a/tests/test_regularizer.py +++ b/tests/test_regularizer.py @@ -222,7 +222,7 @@ def test_regularizer_strength_none(self): with pytest.warns(UserWarning): model.regularizer = regularizer - assert model.regularizer_strength == 1. + assert model.regularizer_strength == 1.0 def test_get_params(self): """Test get_params() returns expected values.""" @@ -406,7 +406,9 @@ def test_solver_match_statsmodels_gamma( model._initialize_parameters(X, y), X, y )[0] with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="The InversePower link function does ") + warnings.filterwarnings( + "ignore", message="The InversePower link function does " + ) model_sm = sm.GLM( endog=y, exog=sm.add_constant(X), family=sm.families.Gamma(link=link_sm) ) @@ -468,9 +470,17 @@ def test_init_solver_name(self, solver_name): with pytest.raises( ValueError, match=f"The solver: {solver_name} is not allowed for " ): - nmo.glm.GLM(regularizer=self.cls(), solver_name=solver_name, regularizer_strength=1.) + nmo.glm.GLM( + regularizer=self.cls(), + solver_name=solver_name, + regularizer_strength=1.0, + ) else: - nmo.glm.GLM(regularizer=self.cls(), solver_name=solver_name, regularizer_strength=1.) + nmo.glm.GLM( + regularizer=self.cls(), + solver_name=solver_name, + regularizer_strength=1.0, + ) @pytest.mark.parametrize( "solver_name", @@ -497,7 +507,7 @@ def test_set_solver_name_allowed(self, solver_name): "ProxSVRG", ] regularizer = self.cls() - model = nmo.glm.GLM(regularizer=regularizer, regularizer_strength=1.) + model = nmo.glm.GLM(regularizer=regularizer, regularizer_strength=1.0) raise_exception = solver_name not in acceptable_solvers if raise_exception: with pytest.raises( @@ -521,14 +531,14 @@ def test_init_solver_kwargs(self, solver_name, solver_kwargs): regularizer=regularizer, solver_name=solver_name, solver_kwargs=solver_kwargs, - regularizer_strength=1. + regularizer_strength=1.0, ) else: nmo.glm.GLM( regularizer=regularizer, solver_name=solver_name, solver_kwargs=solver_kwargs, - regularizer_strength=1. + regularizer_strength=1.0, ) def test_regularizer_strength_none(self): @@ -561,7 +571,7 @@ def test_loss_is_callable(self, loss): """Test Ridge callable loss.""" raise_exception = not callable(loss) regularizer = self.cls() - model = nmo.glm.GLM(regularizer=regularizer, regularizer_strength=1.) + model = nmo.glm.GLM(regularizer=regularizer, regularizer_strength=1.0) model._predict_and_compute_loss = loss if raise_exception: with pytest.raises(TypeError, match="The `loss` must be a Callable"): @@ -579,7 +589,7 @@ def test_run_solver(self, solver_name, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation # set regularizer and solver name - model.set_params(regularizer=self.cls(), regularizer_strength=1.) + model.set_params(regularizer=self.cls(), regularizer_strength=1.0) model.solver_name = solver_name runner = model.instantiate_solver().solver_run runner((true_params[0] * 0.0, true_params[1]), X, y) @@ -594,7 +604,7 @@ def test_run_solver_tree(self, solver_name, poissonGLM_model_instantiation_pytre X, y, model, true_params, firing_rate = poissonGLM_model_instantiation_pytree # set regularizer and solver name - model.set_params(regularizer=self.cls(), regularizer_strength=1.) + model.set_params(regularizer=self.cls(), regularizer_strength=1.0) model.solver_name = solver_name runner = model.instantiate_solver().solver_run runner( @@ -612,7 +622,7 @@ def test_solver_output_match(self, poissonGLM_model_instantiation, solver_name): model.data_type = jnp.float64 # set model params - model.set_params(regularizer=self.cls(), regularizer_strength=1.) + model.set_params(regularizer=self.cls(), regularizer_strength=1.0) model.solver_name = solver_name model.solver_kwargs = {"tol": 10**-12} @@ -643,7 +653,7 @@ def test_solver_match_sklearn(self, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation # set precision to float64 for accurate matching of the results model.data_type = jnp.float64 - model.set_params(regularizer=self.cls(), regularizer_strength=1.) + model.set_params(regularizer=self.cls(), regularizer_strength=1.0) model.solver_kwargs = {"tol": 10**-12} model.solver_name = "BFGS" @@ -670,7 +680,7 @@ def test_solver_match_sklearn_gamma(self, gammaGLM_model_instantiation): # set precision to float64 for accurate matching of the results model.data_type = jnp.float64 model.observation_model.inverse_link_function = jnp.exp - model.set_params(regularizer=self.cls(), regularizer_strength=1.) + model.set_params(regularizer=self.cls(), regularizer_strength=1.0) model.solver_kwargs = {"tol": 10**-12} model.regularizer_strength = 0.1 model.solver_name = "BFGS" @@ -702,7 +712,7 @@ def test_solver_match_sklearn_gamma(self, gammaGLM_model_instantiation): ) def test_solver_combination(self, solver_name, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation - model.set_params(regularizer=self.cls(), regularizer_strength=1.) + model.set_params(regularizer=self.cls(), regularizer_strength=1.0) model.solver_name = solver_name model.fit(X, y) @@ -733,9 +743,15 @@ def test_init_solver_name(self, solver_name): with pytest.raises( ValueError, match=f"The solver: {solver_name} is not allowed for " ): - nmo.glm.GLM(regularizer=self.cls(), solver_name=solver_name, regularizer_strength=1) + nmo.glm.GLM( + regularizer=self.cls(), + solver_name=solver_name, + regularizer_strength=1, + ) else: - nmo.glm.GLM(regularizer=self.cls(), solver_name=solver_name, regularizer_strength=1) + nmo.glm.GLM( + regularizer=self.cls(), solver_name=solver_name, regularizer_strength=1 + ) @pytest.mark.parametrize( "solver_name", @@ -780,14 +796,14 @@ def test_init_solver_kwargs(self, solver_kwargs, solver_name): regularizer=regularizer, solver_name=solver_name, solver_kwargs=solver_kwargs, - regularizer_strength=1. + regularizer_strength=1.0, ) else: nmo.glm.GLM( regularizer=regularizer, solver_name=solver_name, solver_kwargs=solver_kwargs, - regularizer_strength=1. + regularizer_strength=1.0, ) def test_regularizer_strength_none(self): @@ -892,7 +908,7 @@ def test_solver_match_statsmodels( def test_lasso_pytree(self, poissonGLM_model_instantiation_pytree): """Check pytree X can be fit.""" X, y, model, true_params, firing_rate = poissonGLM_model_instantiation_pytree - model.set_params(regularizer=nmo.regularizer.Lasso(), regularizer_strength=1.) + model.set_params(regularizer=nmo.regularizer.Lasso(), regularizer_strength=1.0) model.solver_name = "ProximalGradient" model.fit(X, y) @@ -910,9 +926,12 @@ def test_lasso_pytree_match( X, _, model, _, _ = poissonGLM_model_instantiation_pytree X_array, y, model_array, _, _ = poissonGLM_model_instantiation - - model.set_params(regularizer=nmo.regularizer.Lasso(), regularizer_strength=reg_str) - model_array.set_params(regularizer=nmo.regularizer.Lasso(), regularizer_strength=reg_str) + model.set_params( + regularizer=nmo.regularizer.Lasso(), regularizer_strength=reg_str + ) + model_array.set_params( + regularizer=nmo.regularizer.Lasso(), regularizer_strength=reg_str + ) model.solver_name = solver_name model_array.solver_name = solver_name model.fit(X, y) @@ -924,7 +943,7 @@ def test_lasso_pytree_match( @pytest.mark.parametrize("solver_name", ["ProximalGradient", "ProxSVRG"]) def test_solver_combination(self, solver_name, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation - model.set_params(regularizer=self.cls(), regularizer_strength=1.) + model.set_params(regularizer=self.cls(), regularizer_strength=1.0) model.solver_name = solver_name model.fit(X, y) @@ -962,9 +981,17 @@ def test_init_solver_name(self, solver_name): with pytest.raises( ValueError, match=f"The solver: {solver_name} is not allowed for " ): - nmo.glm.GLM(regularizer=self.cls(mask=mask), solver_name=solver_name, regularizer_strength=1) + nmo.glm.GLM( + regularizer=self.cls(mask=mask), + solver_name=solver_name, + regularizer_strength=1, + ) else: - nmo.glm.GLM(regularizer=self.cls(mask=mask), solver_name=solver_name, regularizer_strength=1) + nmo.glm.GLM( + regularizer=self.cls(mask=mask), + solver_name=solver_name, + regularizer_strength=1, + ) @pytest.mark.parametrize( "solver_name", @@ -1022,14 +1049,14 @@ def test_init_solver_kwargs(self, solver_name, solver_kwargs): regularizer=regularizer, solver_name=solver_name, solver_kwargs=solver_kwargs, - regularizer_strength=1. + regularizer_strength=1.0, ) else: nmo.glm.GLM( regularizer=regularizer, solver_name=solver_name, solver_kwargs=solver_kwargs, - regularizer_strength=1. + regularizer_strength=1.0, ) def test_regularizer_strength_none(self): @@ -1068,7 +1095,7 @@ def test_loss_callable(self, loss): mask = jnp.asarray(mask) regularizer = self.cls(mask=mask) - model = nmo.glm.GLM(regularizer=regularizer, regularizer_strength=1.) + model = nmo.glm.GLM(regularizer=regularizer, regularizer_strength=1.0) model._predict_and_compute_loss = loss if raise_exception: @@ -1089,7 +1116,7 @@ def test_run_solver(self, solver_name, poissonGLM_model_instantiation): mask[1, 2:] = 1 mask = jnp.asarray(mask) - model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.) + model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.0) model.solver_name = solver_name model.instantiate_solver() @@ -1107,7 +1134,7 @@ def test_init_solver(self, solver_name, poissonGLM_model_instantiation): mask[1, 2:] = 1 mask = jnp.asarray(mask) - model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.) + model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.0) model.solver_name = solver_name model.instantiate_solver() @@ -1133,7 +1160,7 @@ def test_update_solver(self, solver_name, poissonGLM_model_instantiation): mask[1, 2:] = 1 mask = jnp.asarray(mask) - model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.) + model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.0) model.solver_name = solver_name model.instantiate_solver() @@ -1155,7 +1182,9 @@ def test_update_solver(self, solver_name, poissonGLM_model_instantiation): and hasattr(state, "_asdict") ) # check params struct and shapes - assert jax.tree_util.tree_structure(params) == jax.tree_util.tree_structure(true_params) + assert jax.tree_util.tree_structure(params) == jax.tree_util.tree_structure( + true_params + ) assert all( jax.tree_util.tree_leaves(params)[k].shape == p.shape for k, p in enumerate(jax.tree_util.tree_leaves(true_params)) @@ -1193,9 +1222,11 @@ def test_mask_validity_groups( with pytest.raises( ValueError, match="Incorrect group assignment. " "Some of the features" ): - model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.) + model.set_params( + regularizer=self.cls(mask=mask), regularizer_strength=1.0 + ) else: - model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.) + model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.0) @pytest.mark.parametrize("set_entry", [0, 1, -1, 2, 2.5]) def test_mask_validity_entries(self, set_entry, poissonGLM_model_instantiation): @@ -1213,9 +1244,11 @@ def test_mask_validity_entries(self, set_entry, poissonGLM_model_instantiation): if raise_exception: with pytest.raises(ValueError, match="Mask elements be 0s and 1s"): - model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.) + model.set_params( + regularizer=self.cls(mask=mask), regularizer_strength=1.0 + ) else: - model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.) + model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.0) @pytest.mark.parametrize("n_dim", [0, 1, 2, 3]) def test_mask_dimension_1(self, n_dim, poissonGLM_model_instantiation): @@ -1242,9 +1275,11 @@ def test_mask_dimension_1(self, n_dim, poissonGLM_model_instantiation): if raise_exception: with pytest.raises(ValueError, match="`mask` must be 2-dimensional"): - model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.) + model.set_params( + regularizer=self.cls(mask=mask), regularizer_strength=1.0 + ) else: - model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.) + model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.0) @pytest.mark.parametrize("n_groups", [0, 1, 2]) def test_mask_n_groups(self, n_groups, poissonGLM_model_instantiation): @@ -1263,9 +1298,11 @@ def test_mask_n_groups(self, n_groups, poissonGLM_model_instantiation): if raise_exception: with pytest.raises(ValueError, match=r"Empty mask provided! Mask has "): - model.set_params(regularizer = self.cls(mask=mask), regularizer_strength=1.) + model.set_params( + regularizer=self.cls(mask=mask), regularizer_strength=1.0 + ) else: - model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.) + model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.0) def test_group_sparsity_enforcement( self, group_sparse_poisson_glm_model_instantiation @@ -1285,7 +1322,7 @@ def test_group_sparsity_enforcement( mask[1, ~zeros_true] = 1 mask = jnp.asarray(mask, dtype=jnp.float32) - model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.) + model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.0) model.solver_name = "ProximalGradient" runner = model.instantiate_solver().solver_run @@ -1429,8 +1466,12 @@ def test_mask_none(self, poissonGLM_model_instantiation): @pytest.mark.parametrize("solver_name", ["ProximalGradient", "ProxSVRG"]) def test_solver_combination(self, solver_name, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation - model.set_params(regularizer=self.cls(mask=np.ones((1, X.shape[1])).astype(float)), - regularizer_strength=None if self.cls==nmo.regularizer.UnRegularized else 1.) + model.set_params( + regularizer=self.cls(mask=np.ones((1, X.shape[1])).astype(float)), + regularizer_strength=( + None if self.cls == nmo.regularizer.UnRegularized else 1.0 + ), + ) model.solver_name = solver_name model.fit(X, y) diff --git a/tests/test_solvers.py b/tests/test_solvers.py index 806f6a4a..6d8f6a7b 100644 --- a/tests/test_solvers.py +++ b/tests/test_solvers.py @@ -128,7 +128,11 @@ def test_svrg_glm_instantiate_solver(regularizer_name, solver_class, mask): if mask is not None: kwargs["mask"] = mask - glm = nmo.glm.GLM(regularizer=regularizer_name, solver_name=solver_name, regularizer_strength=None if regularizer_name == "UnRegularized" else 1,) + glm = nmo.glm.GLM( + regularizer=regularizer_name, + solver_name=solver_name, + regularizer_strength=None if regularizer_name == "UnRegularized" else 1, + ) glm.instantiate_solver() solver = inspect.getclosurevars(glm._solver_run).nonlocals["solver"] @@ -178,9 +182,9 @@ def test_svrg_glm_passes_solver_kwargs(regularizer_name, solver_name, mask, glm_ ( "GroupLasso", ProxSVRG, - np.array([[0.], [0.], [1.]]), + np.array([[0.0], [0.0], [1.0]]), ), - ("GroupLasso", ProxSVRG, np.array([[1.], [0.], [0.]])), + ("GroupLasso", ProxSVRG, np.array([[1.0], [0.0], [0.0]])), ("Ridge", SVRG, None), ("UnRegularized", SVRG, None), ], @@ -233,7 +237,7 @@ def test_svrg_glm_initialize_state( ( "GroupLasso", ProxSVRG, - np.array([[0.], [0.], [1.]]), + np.array([[0.0], [0.0], [1.0]]), ), ("Ridge", SVRG, None), ("UnRegularized", SVRG, None), @@ -345,7 +349,7 @@ def test_svrg_glm_fit( observation_model=nmo.observation_models.PoissonObservations(jax.nn.softplus), solver_kwargs=solver_kwargs, regularizer_strength=None if regularizer_name == "UnRegularized" else 1, - **kwargs + **kwargs, ) if isinstance(glm, nmo.glm.PopulationGLM): diff --git a/tests/test_svrg_defaults.py b/tests/test_svrg_defaults.py index ff6a117e..cc8a3d3e 100644 --- a/tests/test_svrg_defaults.py +++ b/tests/test_svrg_defaults.py @@ -1,8 +1,9 @@ -import pytest +from contextlib import nullcontext as does_not_raise + import jax.numpy as jnp +import pytest from nemos.solvers import _svrg_defaults -from contextlib import nullcontext as does_not_raise @pytest.fixture @@ -197,27 +198,51 @@ def test_warnigns_svrg_optimal_batch_and_stepsize( @pytest.mark.parametrize( "n_power_iter, expectation", [ - (None, pytest.warns(UserWarning, match="Direct computation of the eigenvalues")), + ( + None, + pytest.warns(UserWarning, match="Direct computation of the eigenvalues"), + ), (1, does_not_raise()), (10, does_not_raise()), - ("a", pytest.raises(TypeError, match="`n_power_iters` must be an integer or None")), - (0.5, pytest.raises(TypeError, match="`n_power_iters` must be an integer or None")), - (-1, pytest.raises(ValueError, match="`n_power_iters` must be positive")) - ] + ( + "a", + pytest.raises( + TypeError, match="`n_power_iters` must be an integer or None" + ), + ), + ( + 0.5, + pytest.raises( + TypeError, match="`n_power_iters` must be an integer or None" + ), + ), + (-1, pytest.raises(ValueError, match="`n_power_iters` must be positive")), + ], ) -def test_glm_softplus_poisson_l_smooth_power_iter(x_sample, y_sample, n_power_iter, expectation): +def test_glm_softplus_poisson_l_smooth_power_iter( + x_sample, y_sample, n_power_iter, expectation +): with expectation: - _svrg_defaults._glm_softplus_poisson_l_smooth(x_sample, y_sample, batch_size=1, n_power_iters=n_power_iter) + _svrg_defaults._glm_softplus_poisson_l_smooth( + x_sample, y_sample, batch_size=1, n_power_iters=n_power_iter + ) @pytest.mark.parametrize( "delta_num_sample, expectation", [ (0, does_not_raise()), - (1, pytest.raises(ValueError, match="Each array in data must have the same number")) - ] - ) -def test_svrg_optimal_batch_and_stepsize_num_samples(x_sample, y_sample, delta_num_sample, expectation): + ( + 1, + pytest.raises( + ValueError, match="Each array in data must have the same number" + ), + ), + ], +) +def test_svrg_optimal_batch_and_stepsize_num_samples( + x_sample, y_sample, delta_num_sample, expectation +): y_sample = y_sample[delta_num_sample:] with expectation: _svrg_defaults.svrg_optimal_batch_and_stepsize( @@ -229,30 +254,71 @@ def test_svrg_optimal_batch_and_stepsize_num_samples(x_sample, y_sample, delta_n strong_convexity=0.1, ) + @pytest.mark.parametrize( "num_samples, l_smooth_max, l_smooth, strong_convexity, expected_batch_size", [ # Case 1: strong_convexity is None - (100, 10.0, 2.0, None, 1), # strong_convexity is None, should return batch_size = 1 + ( + 100, + 10.0, + 2.0, + None, + 1, + ), # strong_convexity is None, should return batch_size = 1 # Case 2: num_samples >= 3 * l_smooth_max / strong_convexity - (100, 10.0, 2.0, 0.8, 1), # num_samples >= 3 * l_smooth_max / strong_convexity, should return batch_size = 1 + ( + 100, + 10.0, + 2.0, + 0.8, + 1, + ), # num_samples >= 3 * l_smooth_max / strong_convexity, should return batch_size = 1 # Case 3: num_samples > l_smooth / strong_convexity - (100, 10.0, 2.0, 0.1, 2), # num_samples > l_smooth / strong_convexity, and b_tilde is the minimum + ( + 100, + 10.0, + 2.0, + 0.1, + 2, + ), # num_samples > l_smooth / strong_convexity, and b_tilde is the minimum # Case 4: l_smooth_max < num_samples * l_smooth / 3 and b_hat < b_tilde - (100, 5.0, 0.2, 0.1, 1), # l_smooth_max < num_samples * l_smooth / 3, use minimum(b_hat, b_tilde) + ( + 100, + 5.0, + 0.2, + 0.1, + 1, + ), # l_smooth_max < num_samples * l_smooth / 3, use minimum(b_hat, b_tilde) # Case 5: l_smooth_max >= num_samples * l_smooth / 3 - (100, 10.0, 0.2, 0.01, 27), # l_smooth_max >= num_samples * l_smooth / 3, batch_size = num_samples + ( + 100, + 10.0, + 0.2, + 0.01, + 27, + ), # l_smooth_max >= num_samples * l_smooth / 3, batch_size = num_samples # Case 6: l_smooth_max >= num_samples * l_smooth / 3, but falls back to b_tilde - (100, 5.0, 0.05, 0.1, 1), # l_smooth_max >= num_samples * l_smooth / 3, but falls back to b_tilde + ( + 100, + 5.0, + 0.05, + 0.1, + 1, + ), # l_smooth_max >= num_samples * l_smooth / 3, but falls back to b_tilde # Case 7: l_smooth_max < num_samples * l_smooth / 3 (100, 5.0, 0.5, 0.005, 4), # Case 8: l_smooth_max > num_samples * l_smooth / 3 (100, 18.0, 0.5, 0.005, 100), - ] + ], ) -def test_calculate_optimal_batch_size_svrg_all_config(num_samples, l_smooth_max, l_smooth, strong_convexity, expected_batch_size): +def test_calculate_optimal_batch_size_svrg_all_config( + num_samples, l_smooth_max, l_smooth, strong_convexity, expected_batch_size +): """Test the calculation of the optimal batch size for SVRG.""" batch_size = _svrg_defaults._calculate_optimal_batch_size_svrg( num_samples, l_smooth_max, l_smooth, strong_convexity ) - assert batch_size == expected_batch_size, f"Expected batch_size {expected_batch_size}, got {batch_size}" \ No newline at end of file + assert ( + batch_size == expected_batch_size + ), f"Expected batch_size {expected_batch_size}, got {batch_size}" diff --git a/tests/test_tree_utils.py b/tests/test_tree_utils.py index 33c58850..19070c40 100644 --- a/tests/test_tree_utils.py +++ b/tests/test_tree_utils.py @@ -107,18 +107,23 @@ def test_get_valid_multitree(trees, expected): assert jnp.array_equal(tree_utils.get_valid_multitree(*trees), expected) -@pytest.mark.parametrize("idx", [ - slice(2, 5), # Slice indexing - np.array([1, 3, 5]), # Integer list indexing - np.array([True, False, True, False, True, False, True, False, True, False]), # Boolean array indexing - (slice(1, 3), slice(0, 2)) # Mixed indexing (simple example with slices) -]) +@pytest.mark.parametrize( + "idx", + [ + slice(2, 5), # Slice indexing + np.array([1, 3, 5]), # Integer list indexing + np.array( + [True, False, True, False, True, False, True, False, True, False] + ), # Boolean array indexing + (slice(1, 3), slice(0, 2)), # Mixed indexing (simple example with slices) + ], +) def test_tree_slice(idx): mydict = { - 'array1': np.random.rand(10, 3), - 'array2': np.random.rand(10, 2), - 'array3': np.random.rand(10, 4), - 'array4': jnp.arange(30).reshape(10, 3) + "array1": np.random.rand(10, 3), + "array2": np.random.rand(10, 2), + "array3": np.random.rand(10, 4), + "array4": jnp.arange(30).reshape(10, 3), } result = tree_utils.tree_slice(mydict, idx) for key in mydict: diff --git a/tests/test_type_casting.py b/tests/test_type_casting.py index b7cbbe48..e9e7c6c7 100644 --- a/tests/test_type_casting.py +++ b/tests/test_type_casting.py @@ -370,7 +370,9 @@ def func(*x): ( [ nap.Tsd(t=np.arange(10), d=np.arange(10)), - nap.Tsd(t=np.arange(1), d=np.arange(1), time_support=nap.IntervalSet(0, 10)), + nap.Tsd( + t=np.arange(1), d=np.arange(1), time_support=nap.IntervalSet(0, 10) + ), nap.Tsd(t=np.arange(10), d=np.arange(10)), ], pytest.raises( diff --git a/tests/test_utils.py b/tests/test_utils.py index abe46314..3dc93236 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -109,8 +109,11 @@ def test_conv_type(self, iterable, predictor_causality): utils.nan_pad(iterable, 3, predictor_causality) else: with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning, - message="With acausal filter, pad_size should probably be even") + warnings.filterwarnings( + "ignore", + category=UserWarning, + message="With acausal filter, pad_size should probably be even", + ) utils.nan_pad(iterable, 3, predictor_causality) @pytest.mark.parametrize("iterable", [[np.zeros([2, 4, 5]), np.zeros([2, 4, 6])]]) @@ -175,8 +178,11 @@ def test_padding_nan_acausal(self, pad_size, iterable): else: init_nan, end_nan = pad_size // 2, pad_size - pad_size // 2 with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning, - message="With acausal filter, pad_size should probably be even") + warnings.filterwarnings( + "ignore", + category=UserWarning, + message="With acausal filter, pad_size should probably be even", + ) padded = utils.nan_pad(iterable, pad_size, "acausal") for trial in padded: print(trial.shape, pad_size) @@ -260,8 +266,11 @@ def test_nan_pad_conv_dtype(self, dtype, expectation): ) def test_axis_compatibility(self, pad_size, array, causality, axis, expectation): with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning, - message="With acausal filter, pad_size should probably be even") + warnings.filterwarnings( + "ignore", + category=UserWarning, + message="With acausal filter, pad_size should probably be even", + ) with expectation: utils.nan_pad(array, pad_size, causality, axis=axis) @@ -284,8 +293,11 @@ def test_axis_compatibility(self, pad_size, array, causality, axis, expectation) @pytest.mark.parametrize("array", [jnp.zeros((10,)), np.zeros((10, 11))]) def test_pad_size_type(self, pad_size, array, causality, expectation): with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning, - message="With acausal filter, pad_size should probably be even") + warnings.filterwarnings( + "ignore", + category=UserWarning, + message="With acausal filter, pad_size should probably be even", + ) with expectation: utils.nan_pad(array, pad_size, causality, axis=0) diff --git a/tox.ini b/tox.ini index 68b4e36d..86557fb2 100644 --- a/tox.ini +++ b/tox.ini @@ -23,6 +23,8 @@ commands= isort docs/how_to_guide --profile=black isort docs/background --profile=black isort docs/tutorials --profile=black + black tests + isort tests --profile=black [testenv:check] commands=