diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 741b93b7..418c2580 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -211,8 +211,8 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: >>> # Transform basis >>> feature_transformed = transformer.transform(X) """ - self._check_input(X, y) self._check_initialized(self._basis) + self._check_input(X, y) # transpose does not work with pynapple # can't use func(*X.T) to unwrap return self._basis._compute_features(*self._unpack_inputs(X)) diff --git a/tests/test_transformer_basis.py b/tests/test_transformer_basis.py index a99d785a..79fa1848 100644 --- a/tests/test_transformer_basis.py +++ b/tests/test_transformer_basis.py @@ -3,15 +3,14 @@ import numpy as np import pytest - from conftest import CombinedBasis, list_all_basis_classes from sklearn.base import clone as sk_clone from sklearn.pipeline import Pipeline import nemos as nmo from nemos import basis -from nemos._inspect_utils import list_abstract_methods, get_subclass_methods -from nemos.basis import AdditiveBasis, MultiplicativeBasis +from nemos._inspect_utils import get_subclass_methods, list_abstract_methods +from nemos.basis import AdditiveBasis, MSplineConv, MultiplicativeBasis @pytest.mark.parametrize( @@ -699,12 +698,13 @@ def test_transformer_fit_transform_input_struct( @pytest.mark.parametrize( "inp", [ - np.random.randn( + 0.1 + * np.random.randn( 100, ), - np.random.randn(100, 1), - np.random.randn(100, 2), - np.random.randn(100, 1, 2), + 0.1 * np.random.randn(100, 1), + 0.1 * np.random.randn(100, 2), + 0.1 * np.random.randn(100, 1, 2), ], ) def test_transformer_in_pipeline(basis_cls, inp, basis_class_specific_params): @@ -756,6 +756,7 @@ def test_transformer_in_pipeline(basis_cls, inp, basis_class_specific_params): model.fit(X, y) np.testing.assert_allclose(pipe["glm"].coef_, model.coef_) + @pytest.mark.parametrize( "basis_cls", list_all_basis_classes(), @@ -766,7 +767,7 @@ def test_initialization(basis_cls, basis_class_specific_params): ) transformer = bas.to_transformer() with pytest.raises(RuntimeError, match="Cannot apply TransformerBasis"): - transformer.fit(np.ones((100, ))) + transformer.fit(np.ones((100,))) with pytest.raises(RuntimeError, match="Cannot apply TransformerBasis"): transformer.transform(np.ones((100,))) @@ -784,7 +785,9 @@ def test_basis_setter(basis_cls, basis_class_specific_params): 5, basis_cls, basis_class_specific_params, window_size=10 ) - bas2 = CombinedBasis().instantiate_basis(7, basis_cls, basis_class_specific_params, window_size=10) + bas2 = CombinedBasis().instantiate_basis( + 7, basis_cls, basis_class_specific_params, window_size=10 + ) transformer = bas.to_transformer() transformer.basis = bas2 assert transformer.basis.n_basis_funcs == bas2.n_basis_funcs @@ -811,17 +814,20 @@ def test_eetstate(basis_cls, basis_class_specific_params): bas = CombinedBasis().instantiate_basis( 5, basis_cls, basis_class_specific_params, window_size=10 ) - bas2 = CombinedBasis().instantiate_basis(7, basis_cls, basis_class_specific_params, window_size=10) + bas2 = CombinedBasis().instantiate_basis( + 7, basis_cls, basis_class_specific_params, window_size=10 + ) transformer = bas.to_transformer() state = {"_basis": bas2} transformer.__setstate__(state) - assert transformer.basis == bas2 + assert transformer.basis == bas2 + @pytest.mark.parametrize( "basis_cls", list_all_basis_classes(), ) -def test_eetstate(basis_cls, basis_class_specific_params): +def test_getstate(basis_cls, basis_class_specific_params): bas = CombinedBasis().instantiate_basis( 5, basis_cls, basis_class_specific_params, window_size=10 ) @@ -835,13 +841,14 @@ def test_eetstate(basis_cls, basis_class_specific_params): # check all reimplemented methods dict_reimplemented_method = get_subclass_methods(basis_cls) - for meth in dict_abst_method: + for meth in dict_reimplemented_method: assert meth[0] in lst # check that it is a trnasformer for meth in ["fit", "transform", "fit_transform"]: assert meth in lst + @pytest.mark.parametrize( "basis_cls", list_all_basis_classes(), @@ -849,11 +856,20 @@ def test_eetstate(basis_cls, basis_class_specific_params): @pytest.mark.parametrize( "inp, expectation", [ - (np.random.randn(10, 2), pytest.raises(ValueError, match="Input mismatch: expected \d inputs")), - (np.random.randn(10, 3, 1), pytest.raises(ValueError, match="X must be 2-dimensional")), - ({1: np.random.randn(10, 3)}, pytest.raises(ValueError, match="The input must be a 2-dimensional array")), + ( + np.random.randn(10, 2), + pytest.raises(ValueError, match="Input mismatch: expected \d inputs"), + ), + ( + np.random.randn(10, 3, 1), + pytest.raises(ValueError, match="X must be 2-dimensional"), + ), + ( + {1: np.random.randn(10, 3)}, + pytest.raises(ValueError, match="The input must be a 2-dimensional array"), + ), (np.random.randn(10, 3), does_not_raise()), - ] + ], ) @pytest.mark.parametrize("method", ["fit", "transform", "fit_transform"]) def test_check_input(inp, expectation, basis_cls, basis_class_specific_params, method): @@ -863,11 +879,15 @@ def test_check_input(inp, expectation, basis_cls, basis_class_specific_params, m # set kernels bas._set_input_independent_states() # set input shape - transformer = bas.to_transformer().set_input_shape(*([3] * bas._n_input_dimensionality)) + transformer = bas.to_transformer().set_input_shape( + *([3] * bas._n_input_dimensionality) + ) if isinstance(bas, (AdditiveBasis, MultiplicativeBasis)): if hasattr(inp, "ndim"): ndim = inp.ndim - inp = np.concatenate([inp.reshape(inp.shape[0], -1)] * bas._n_input_dimensionality, axis=1) + inp = np.concatenate( + [inp.reshape(inp.shape[0], -1)] * bas._n_input_dimensionality, axis=1 + ) if ndim == 3: inp = inp[..., np.newaxis]