Skip to content

Commit

Permalink
added test
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Oct 11, 2024
1 parent 8f1b832 commit 1397b6e
Showing 1 changed file with 59 additions and 0 deletions.
59 changes: 59 additions & 0 deletions tests/test_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7119,3 +7119,62 @@ def test_duplicate_keys(bas1, bas2, bas3):
)
slice_dict = bas_obj._get_feature_slicing()[0]
assert tuple(slice_dict.keys()) == ("label", "label-1", "label-2")


@pytest.mark.parametrize(
"bas1, bas2",
list(
itertools.product(
*[tuple((getattr(basis, basis_name) for basis_name in dir(basis)))] * 2
)
),
)
@pytest.mark.parametrize(
"x, axis, expectation, exp_shapes", # num output is 5*2 + 6*3 = 28
[
(np.ones((1, 28)), 1, does_not_raise(), [(1, 5), (1,6)]),
(np.ones((28, )), 0, does_not_raise(), [(5,), (6,)]),
(np.ones((2, 2, 28)), 2, does_not_raise(), [(2, 2, 5), (2, 2, 6)]),
(np.ones((2, 2, 27)), 2, pytest.raises(ValueError, match=r"`x.shape\[axis\]` does not match the expected"), [(2, 2, 5), (2, 2, 6)]),
]
)
def test_split_feature_axis(bas1, bas2, x, axis, expectation, exp_shapes):
# skip nested
if any(
bas in (basis.AdditiveBasis, basis.MultiplicativeBasis, basis.TransformerBasis)
for bas in [bas1, bas2]
):
return
# define the basis
n_basis = [5, 6]
mode = "conv"
extra_kwargs = (
{"decay_rates": np.arange(1, n_basis[0] + 1), "window_size": 5},
{"decay_rates": np.arange(1, n_basis[1] + 1), "window_size": 5},
)
for i, val in enumerate(zip([bas1, bas2], extra_kwargs)):
bas, kwrgs = val
if bas != basis.OrthExponentialBasis:
kwrgs.pop("decay_rates")

bas1_instance = bas1(
n_basis[0],
mode=mode,
n_basis_input=2,
**extra_kwargs[0],
label="1",
)
bas2_instance = bas2(
n_basis[1],
mode=mode,
n_basis_input=3,
**extra_kwargs[1],
label="2",
)
bas = bas1_instance + bas2_instance
with expectation:
out = bas.split_feature_axis(x, axis=axis)
basis_list = [bas1_instance, bas2_instance]
for i, itm in enumerate(out.items()):
key, val = itm
assert all(v.shape == exp_shapes[i] for v in val.values())

0 comments on commit 1397b6e

Please sign in to comment.