diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 8e212cbf..44b024db 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -307,11 +307,7 @@ def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor): structure: a single (X, y) pair for the transformer, a number of time series for the Basis. """ - if self.kernel_ is None: - raise ValueError( - "You must call `setup_basis` before `_compute_features`! " - "Convolution kernel is not set." - ) + self._check_has_kernel() # before calling the convolve, check that the input matches # the expectation. We can check xi[0] only, since convolution # is applied at the end of the recursion on the 1D basis, ensuring len(xi) == 1. @@ -457,8 +453,8 @@ def _check_convolution_kwargs(conv_kwargs: dict): def _check_has_kernel(self) -> None: """Check that the kernel is pre-computed.""" if self.kernel_ is None: - raise ValueError( - "You must call `_set_kernel` before `_compute_features` for Conv basis." + raise RuntimeError( + "You must call `setup_basis` before `_compute_features` for Conv basis." ) @@ -517,7 +513,7 @@ def __init__(self, basis1: Basis, basis2: Basis): *(bas2._input_shape_ for bas2 in basis2._iterate_over_components()), ) # if all bases where set, then set input for composition. - set_bases = (s is not None for s in shapes) + set_bases = [s is not None for s in shapes] if all(set_bases): # pass down the input shapes diff --git a/tests/test_basis.py b/tests/test_basis.py index 5168cc44..0f69850e 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -152,6 +152,15 @@ def method(self): pass assert CustomSubClass().method.__doc__ == "My extra text.\nMy custom method." + with pytest.raises(AttributeError, match="CustomClass has no attribute"): + + class CustomSubClass2(CustomClass): + @custom_add_docstring("unknown", cls=CustomClass) + def method(self): + """My custom method.""" + pass + + CustomSubClass2() @pytest.mark.parametrize( @@ -241,18 +250,34 @@ def test_expected_output_compute_features(basis_instance, super_class): ), OrthExponentialBasis, ), + ( + basis.OrthExponentialConv( + 10, decay_rates=np.arange(1, 11), window_size=12, label="a" + ) + * basis.RaisedCosineLogConv(10, window_size=11, label="b"), + OrthExponentialBasis, + ), + ( + basis.OrthExponentialConv( + 10, decay_rates=np.arange(1, 11), window_size=12, label="a" + ) + + basis.RaisedCosineLogConv(10, window_size=11, label="b"), + OrthExponentialBasis, + ), ], ) def test_expected_output_split_by_feature(basis_instance, super_class): - x = super_class.compute_features(basis_instance, np.linspace(0, 1, 100)) + inp = [np.linspace(0, 1, 100)] * basis_instance._n_input_dimensionality + x = super_class.compute_features(basis_instance, *inp) xdict = super_class.split_by_feature(basis_instance, x) xxdict = basis_instance.split_by_feature(x) assert xdict.keys() == xxdict.keys() - xx = xxdict["label"] - x = xdict["label"] - nans = np.isnan(x.sum(axis=(1, 2))) - assert np.all(np.isnan(xx[nans])) - np.testing.assert_array_equal(xx[~nans], x[~nans]) + for k in xdict.keys(): + xx = xxdict[k] + x = xdict[k] + nans = np.isnan(x.sum(axis=(1, 2))) + assert np.all(np.isnan(xx[nans])) + np.testing.assert_array_equal(xx[~nans], x[~nans]) @pytest.mark.parametrize( @@ -1236,7 +1261,7 @@ def test_transform_fails(self, cls): n_basis_funcs=5, window_size=5, **extra_decay_rates(cls["conv"], 5) ) with pytest.raises( - ValueError, match="You must call `setup_basis` before `_compute_features`" + RuntimeError, match="You must call `setup_basis` before `_compute_features`" ): bas._compute_features(np.linspace(0, 1, 10)) @@ -1579,6 +1604,15 @@ def test_minimum_number_of_basis_required_is_matched( n_basis_funcs=n_basis_funcs, order=order, **kwargs ) basis_obj.compute_features(np.linspace(0, 1, 10)) + + # test the setter valuerror + if (order > 1) & (n_basis_funcs > 1): + basis_obj = self.cls[mode](n_basis_funcs=20, order=order, **kwargs) + with pytest.raises( + ValueError, + match=rf"{self.cls[mode].__name__} `order` parameter cannot be larger than", + ): + basis_obj.n_basis_funcs = n_basis_funcs else: basis_obj = self.cls[mode]( n_basis_funcs=n_basis_funcs, order=order, **kwargs @@ -2210,6 +2244,41 @@ def test_number_of_required_inputs_compute_features( with expectation: basis_obj.compute_features(*inputs) + @pytest.mark.parametrize("basis_a", list_all_basis_classes()) + @pytest.mark.parametrize("basis_b", list_all_basis_classes()) + @pytest.mark.parametrize("n_basis_a", [5]) + @pytest.mark.parametrize("n_basis_b", [6]) + @pytest.mark.parametrize("window_size", [10]) + def test_warn_partial_setup( + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, + ): + basis_a_obj = self.instantiate_basis( + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size + ) + basis_b_obj = self.instantiate_basis( + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size + ) + + basis_a_obj.set_input_shape(*([1] * basis_a_obj._n_input_dimensionality)) + with pytest.warns(UserWarning, match="Only some of the basis where"): + basis_a_obj + basis_b_obj + + # check that if both set addition is fine + basis_b_obj.set_input_shape(*([1] * basis_b_obj._n_input_dimensionality)) + basis_a_obj + basis_b_obj + + basis_a_obj = self.instantiate_basis( + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size + ) + with pytest.warns(UserWarning, match="Only some of the basis where"): + basis_a_obj + basis_b_obj + @pytest.mark.parametrize("sample_size", [11, 20]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @@ -2661,7 +2730,7 @@ def test_transform_fails( context = does_not_raise() else: context = pytest.raises( - ValueError, + RuntimeError, match="You must call `setup_basis` before `_compute_features`", ) with context: @@ -3589,7 +3658,7 @@ def test_transform_fails( context = does_not_raise() else: context = pytest.raises( - ValueError, + RuntimeError, match="You must call `setup_basis` before `_compute_features`", ) with context: