diff --git a/tests/test_basis.py b/tests/test_basis.py index c8412897..3f86ebe7 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -27,31 +27,6 @@ from nemos.utils import pynapple_concatenate_numpy -@pytest.fixture() -def class_specific_params(): - shared_params = ["n_basis_funcs", "label"] - eval_params = ["bounds"] - conv_params = ["window_size", "conv_kwargs"] - return dict( - EvalBSpline=shared_params + eval_params + ["order"], - ConvBSpline=shared_params + conv_params + ["order"], - EvalMSpline=shared_params + eval_params + ["order"], - ConvMSpline=shared_params + conv_params + ["order"], - EvalCyclicBSpline=shared_params + eval_params + ["order"], - ConvCyclicBSpline=shared_params + conv_params + ["order"], - EvalRaisedCosineLinear=shared_params + eval_params + ["width"], - ConvRaisedCosineLinear=shared_params + conv_params + ["width"], - EvalRaisedCosineLog=shared_params - + eval_params - + ["width", "time_scaling", "enforce_decay_to_zero"], - ConvRaisedCosineLog=shared_params - + conv_params - + ["width", "time_scaling", "enforce_decay_to_zero"], - EvalOrthExponential=shared_params + eval_params + ["decay_rates"], - ConvOrthExponential=shared_params + conv_params + ["decay_rates"], - ) - - def trim_kwargs(cls, kwargs, class_specific_params): return { key: value @@ -87,6 +62,13 @@ def list_all_basis_classes(filter_basis="all") -> list[type]: return all_basis +@pytest.fixture() +def class_specific_params(): + """Returns all the params for each class.""" + all_cls = list_all_basis_classes("Conv") + list_all_basis_classes("Eval") + return {cls.__name__: cls._get_param_names() for cls in all_cls} + + def test_all_basis_are_tested() -> None: """Meta-test. diff --git a/tests/test_identifiability_constraints.py b/tests/test_identifiability_constraints.py index ace846ea..0fda51e9 100644 --- a/tests/test_identifiability_constraints.py +++ b/tests/test_identifiability_constraints.py @@ -215,7 +215,9 @@ def test_apply_constraint_by_basis_with_invalid(invalid_entries): ) # add invalid x[:2, 2] = invalid_entries - constrained_x, kept_cols = apply_identifiability_constraints(x, warn_if_float32=False) + constrained_x, kept_cols = apply_identifiability_constraints( + x, warn_if_float32=False + ) assert jnp.array_equal(kept_cols, jnp.arange(1, 5)) assert constrained_x.shape[0] == x.shape[0] assert jnp.all(jnp.isnan(constrained_x[:2]))