Skip to content

Commit

Permalink
generalized tests
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 2, 2024
1 parent 6a0574d commit b7a7b60
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 26 deletions.
32 changes: 7 additions & 25 deletions tests/test_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,6 @@
from nemos.utils import pynapple_concatenate_numpy


@pytest.fixture()
def class_specific_params():
shared_params = ["n_basis_funcs", "label"]
eval_params = ["bounds"]
conv_params = ["window_size", "conv_kwargs"]
return dict(
EvalBSpline=shared_params + eval_params + ["order"],
ConvBSpline=shared_params + conv_params + ["order"],
EvalMSpline=shared_params + eval_params + ["order"],
ConvMSpline=shared_params + conv_params + ["order"],
EvalCyclicBSpline=shared_params + eval_params + ["order"],
ConvCyclicBSpline=shared_params + conv_params + ["order"],
EvalRaisedCosineLinear=shared_params + eval_params + ["width"],
ConvRaisedCosineLinear=shared_params + conv_params + ["width"],
EvalRaisedCosineLog=shared_params
+ eval_params
+ ["width", "time_scaling", "enforce_decay_to_zero"],
ConvRaisedCosineLog=shared_params
+ conv_params
+ ["width", "time_scaling", "enforce_decay_to_zero"],
EvalOrthExponential=shared_params + eval_params + ["decay_rates"],
ConvOrthExponential=shared_params + conv_params + ["decay_rates"],
)


def trim_kwargs(cls, kwargs, class_specific_params):
return {
key: value
Expand Down Expand Up @@ -87,6 +62,13 @@ def list_all_basis_classes(filter_basis="all") -> list[type]:
return all_basis


@pytest.fixture()
def class_specific_params():
"""Returns all the params for each class."""
all_cls = list_all_basis_classes("Conv") + list_all_basis_classes("Eval")
return {cls.__name__: cls._get_param_names() for cls in all_cls}


def test_all_basis_are_tested() -> None:
"""Meta-test.
Expand Down
4 changes: 3 additions & 1 deletion tests/test_identifiability_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,9 @@ def test_apply_constraint_by_basis_with_invalid(invalid_entries):
)
# add invalid
x[:2, 2] = invalid_entries
constrained_x, kept_cols = apply_identifiability_constraints(x, warn_if_float32=False)
constrained_x, kept_cols = apply_identifiability_constraints(
x, warn_if_float32=False
)
assert jnp.array_equal(kept_cols, jnp.arange(1, 5))
assert constrained_x.shape[0] == x.shape[0]
assert jnp.all(jnp.isnan(constrained_x[:2]))

0 comments on commit b7a7b60

Please sign in to comment.