diff --git a/tests/test_basis.py b/tests/test_basis.py index 2ef6cca1..de887638 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -7119,3 +7119,62 @@ def test_duplicate_keys(bas1, bas2, bas3): ) slice_dict = bas_obj._get_feature_slicing()[0] assert tuple(slice_dict.keys()) == ("label", "label-1", "label-2") + + +@pytest.mark.parametrize( + "bas1, bas2", + list( + itertools.product( + *[tuple((getattr(basis, basis_name) for basis_name in dir(basis)))] * 2 + ) + ), +) +@pytest.mark.parametrize( + "x, axis, expectation, exp_shapes", # num output is 5*2 + 6*3 = 28 + [ + (np.ones((1, 28)), 1, does_not_raise(), [(1, 5), (1,6)]), + (np.ones((28, )), 0, does_not_raise(), [(5,), (6,)]), + (np.ones((2, 2, 28)), 2, does_not_raise(), [(2, 2, 5), (2, 2, 6)]), + (np.ones((2, 2, 27)), 2, pytest.raises(ValueError, match=r"`x.shape\[axis\]` does not match the expected"), [(2, 2, 5), (2, 2, 6)]), + ] +) +def test_split_feature_axis(bas1, bas2, x, axis, expectation, exp_shapes): + # skip nested + if any( + bas in (basis.AdditiveBasis, basis.MultiplicativeBasis, basis.TransformerBasis) + for bas in [bas1, bas2] + ): + return + # define the basis + n_basis = [5, 6] + mode = "conv" + extra_kwargs = ( + {"decay_rates": np.arange(1, n_basis[0] + 1), "window_size": 5}, + {"decay_rates": np.arange(1, n_basis[1] + 1), "window_size": 5}, + ) + for i, val in enumerate(zip([bas1, bas2], extra_kwargs)): + bas, kwrgs = val + if bas != basis.OrthExponentialBasis: + kwrgs.pop("decay_rates") + + bas1_instance = bas1( + n_basis[0], + mode=mode, + n_basis_input=2, + **extra_kwargs[0], + label="1", + ) + bas2_instance = bas2( + n_basis[1], + mode=mode, + n_basis_input=3, + **extra_kwargs[1], + label="2", + ) + bas = bas1_instance + bas2_instance + with expectation: + out = bas.split_feature_axis(x, axis=axis) + basis_list = [bas1_instance, bas2_instance] + for i, itm in enumerate(out.items()): + key, val = itm + assert all(v.shape == exp_shapes[i] for v in val.values()) \ No newline at end of file