From 63b04e4ed4edffbd341e6c5bf796933718297484 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 5 Dec 2024 16:54:52 -0500 Subject: [PATCH 01/37] added improved unpacking of transformer basis. --- src/nemos/basis/_basis.py | 16 ++++++++------- src/nemos/basis/_basis_mixin.py | 23 +++++++++++++++++++++ src/nemos/basis/_transformer_basis.py | 29 +++++++++++++++++++-------- 3 files changed, 53 insertions(+), 15 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index fb9ee3cd..9dd55f28 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -820,6 +820,10 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): self._n_output_features = self.n_basis_funcs * self._n_basis_input[0] return self + def _list_components(self): + return [self] + + class AdditiveBasis(CompositeBasisMixin, Basis): """ @@ -857,16 +861,15 @@ class AdditiveBasis(CompositeBasisMixin, Basis): def __init__(self, basis1: Basis, basis2: Basis) -> None: self.n_basis_funcs = basis1.n_basis_funcs + basis2.n_basis_funcs - super().__init__(self.n_basis_funcs, mode="eval") + CompositeBasisMixin.__init__(self, basis1, basis2) + Basis.__init__(self, self.n_basis_funcs, mode="eval") self._n_input_dimensionality = ( basis1._n_input_dimensionality + basis2._n_input_dimensionality ) self._n_basis_input = None self._n_output_features = None self._label = "(" + basis1.label + " + " + basis2.label + ")" - self._basis1 = basis1 - self._basis2 = basis2 - CompositeBasisMixin.__init__(self) + def set_input_shape(self, *xi: int | tuple[int, ...] 
| NDArray) -> Basis: """ @@ -1226,8 +1229,8 @@ class MultiplicativeBasis(CompositeBasisMixin, Basis): def __init__(self, basis1: Basis, basis2: Basis) -> None: self.n_basis_funcs = basis1.n_basis_funcs * basis2.n_basis_funcs - CompositeBasisMixin.__init__(self) - super().__init__(self.n_basis_funcs, mode="eval") + CompositeBasisMixin.__init__(self, basis1, basis2) + Basis.__init__(self, self.n_basis_funcs, mode="eval") self._n_input_dimensionality = ( basis1._n_input_dimensionality + basis2._n_input_dimensionality ) @@ -1237,7 +1240,6 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None: self._basis1 = basis1 self._basis2 = basis2 BasisTransformerMixin.__init__(self) - CompositeBasisMixin.__init__(self) def set_kernel(self, *xi: NDArray) -> Basis: """Call fit on the multiplied basis. diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 9bd1b09c..954f31dc 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -270,6 +270,10 @@ class CompositeBasisMixin: (AdditiveBasis and MultiplicativeBasis). 
""" + def __init__(self, basis1: Basis, basis2: Basis): + self.basis1 = basis1 + self.basis2 = basis2 + def _check_n_basis_min(self) -> None: pass @@ -300,3 +304,22 @@ def _check_input_shape_consistency(self, *xi: NDArray): self._basis2._check_input_shape_consistency( *xi[self._basis1._n_input_dimensionality :] ) + + @property + def basis1(self): + return self._basis1 + + @basis1.setter + def basis1(self, bas: Basis): + self._basis1 = bas + + @property + def basis2(self): + return self._basis2 + + @basis2.setter + def basis2(self, bas: Basis): + self._basis2 = bas + + def _list_components(self): + return self._basis1._list_components() + self._basis2._list_components() diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 9068ca9c..2db0e2a7 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -1,7 +1,9 @@ from __future__ import annotations import copy -from typing import TYPE_CHECKING +from typing import List, TYPE_CHECKING + +import numpy as np from ..typing import FeatureMatrix @@ -60,15 +62,19 @@ class TransformerBasis: """ def __init__(self, basis: Basis): + if basis._n_basis_input is None: + raise RuntimeError( + "TransformerBasis initialization failed: the provided basis has no defined input shape. " + "Please call `set_input_shape` on the basis before initializing the transformer." + ) + self._basis = copy.deepcopy(basis) - @staticmethod - def _unpack_inputs(X: FeatureMatrix): - """Unpack impute without using transpose. + def _unpack_inputs(self, X: FeatureMatrix) -> List: + """Unpack inputs. Unpack horizontally stacked inputs using slicing. This works gracefully with ``pynapple``, - returning a list of Tsd objects. Attempt to unpack using *X.T will raise a ``pynapple`` - exception since ``pynapple`` assumes that the time axis is the first axis. + returning a list of Tsd objects. 
Parameters ---------- @@ -78,10 +84,17 @@ def _unpack_inputs(X: FeatureMatrix): Returns ------- : - A tuple of each individual input. + A list of each individual input. """ - return (X[:, k] for k in range(X.shape[1])) + n_samples = X.shape[0] + out = [] + cc = 0 + for i, bas in enumerate(self._basis._list_components()): + n_input = self._n_basis_input[i] + out.append(np.reshape(X[:, cc:cc + n_input], (n_samples, *bas._input_shape))) + cc += n_input + return out def fit(self, X: FeatureMatrix, y=None): """ From 0785e31de4422cf0f331d51d14c6089b16f38fae Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 5 Dec 2024 17:03:54 -0500 Subject: [PATCH 02/37] added docstrings --- src/nemos/basis/_basis.py | 18 ++++++++++++++++++ src/nemos/basis/_basis_mixin.py | 9 +++++++++ 2 files changed, 27 insertions(+) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 9dd55f28..fd243d5a 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -821,6 +821,24 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): return self def _list_components(self): + """List all basis components. + + This is re-implemented for composite basis in the mixin class. + + Returns + ------- + A list with all 1d basis components. + + Raises + ------ + RuntimeError + If the basis has multiple components. This would only happen if there is an + implementation issue, for example, if a composite basis is implemented but the + mixin class is not initialized, or if the _list_components method of the composite mixin + class is accidentally removed. 
+ """ + if hasattr(self, "basis1"): + raise RuntimeError("Composite basis must implement the _list_components method.") return [self] diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 954f31dc..da96f2d3 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -322,4 +322,13 @@ def basis2(self, bas: Basis): self._basis2 = bas def _list_components(self): + """List all basis components. + + Reimplements the default behavior by iteratively calling _list_components of the + elements. + + Returns + ------- + A list with all 1d basis components. + """ return self._basis1._list_components() + self._basis2._list_components() From e2bea6bbaded6f24337e4e7e4a140cf4d4b706fd Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 5 Dec 2024 17:57:18 -0500 Subject: [PATCH 03/37] linted and fixed shared tests --- src/nemos/basis/_basis.py | 21 ++--- src/nemos/basis/_basis_mixin.py | 16 ++++ src/nemos/basis/_transformer_basis.py | 21 ++++- tests/test_basis.py | 110 ++++++++++++++++---------- 4 files changed, 110 insertions(+), 58 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index fd243d5a..2305f5ea 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -838,11 +838,12 @@ def _list_components(self): class is accidentally removed. """ if hasattr(self, "basis1"): - raise RuntimeError("Composite basis must implement the _list_components method.") + raise RuntimeError( + "Composite basis must implement the _list_components method." + ) return [self] - class AdditiveBasis(CompositeBasisMixin, Basis): """ Class representing the addition of two Basis objects. 
@@ -879,15 +880,13 @@ class AdditiveBasis(CompositeBasisMixin, Basis): def __init__(self, basis1: Basis, basis2: Basis) -> None: self.n_basis_funcs = basis1.n_basis_funcs + basis2.n_basis_funcs - CompositeBasisMixin.__init__(self, basis1, basis2) Basis.__init__(self, self.n_basis_funcs, mode="eval") + self._label = "(" + basis1.label + " + " + basis2.label + ")" + CompositeBasisMixin.__init__(self, basis1, basis2) + self._n_input_dimensionality = ( basis1._n_input_dimensionality + basis2._n_input_dimensionality ) - self._n_basis_input = None - self._n_output_features = None - self._label = "(" + basis1.label + " + " + basis2.label + ")" - def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: """ @@ -1247,16 +1246,12 @@ class MultiplicativeBasis(CompositeBasisMixin, Basis): def __init__(self, basis1: Basis, basis2: Basis) -> None: self.n_basis_funcs = basis1.n_basis_funcs * basis2.n_basis_funcs - CompositeBasisMixin.__init__(self, basis1, basis2) Basis.__init__(self, self.n_basis_funcs, mode="eval") + self._label = "(" + basis1.label + " * " + basis2.label + ")" + CompositeBasisMixin.__init__(self, basis1, basis2) self._n_input_dimensionality = ( basis1._n_input_dimensionality + basis2._n_input_dimensionality ) - self._n_basis_input = None - self._n_output_features = None - self._label = "(" + basis1.label + " * " + basis2.label + ")" - self._basis1 = basis1 - self._basis2 = basis2 BasisTransformerMixin.__init__(self) def set_kernel(self, *xi: NDArray) -> Basis: diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index da96f2d3..d5ae69c6 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -4,6 +4,7 @@ import copy import inspect +import warnings from typing import TYPE_CHECKING, Optional, Tuple, Union import numpy as np @@ -273,6 +274,21 @@ class CompositeBasisMixin: def __init__(self, basis1: Basis, basis2: Basis): self.basis1 = basis1 self.basis2 = basis2 + shapes = ( + 
*(bas1._input_shape for bas1 in basis1._list_components()), + *(bas2._input_shape for bas2 in basis2._list_components()), + ) + # if all bases where set, then set input for composition. + set_bases = (s is not None for s in shapes) + if all(set_bases): + # pass down the input shapes + self.set_input_shape(*shapes) + elif any(set_bases): + warnings.warn( + "Only some of the basis where initialized with `set_input_shape`, " + "please initialize the composite basis before computing features.", + category=UserWarning, + ) def _check_n_basis_min(self) -> None: pass diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 2db0e2a7..108fee6c 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -1,7 +1,7 @@ from __future__ import annotations import copy -from typing import List, TYPE_CHECKING +from typing import TYPE_CHECKING, List import numpy as np @@ -62,13 +62,24 @@ class TransformerBasis: """ def __init__(self, basis: Basis): + self.basis = copy.deepcopy(basis) + + @staticmethod + def _check_initialized(basis): if basis._n_basis_input is None: raise RuntimeError( "TransformerBasis initialization failed: the provided basis has no defined input shape. " "Please call `set_input_shape` on the basis before initializing the transformer." ) - self._basis = copy.deepcopy(basis) + @property + def basis(self): + return self._basis + + @basis.setter + def basis(self, basis): + self._check_initialized(basis) + self._basis = basis def _unpack_inputs(self, X: FeatureMatrix) -> List: """Unpack inputs. 
@@ -92,7 +103,9 @@ def _unpack_inputs(self, X: FeatureMatrix) -> List: cc = 0 for i, bas in enumerate(self._basis._list_components()): n_input = self._n_basis_input[i] - out.append(np.reshape(X[:, cc:cc + n_input], (n_samples, *bas._input_shape))) + out.append( + np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape)) + ) cc += n_input return out @@ -276,7 +289,7 @@ def __setattr__(self, name: str, value) -> None: ValueError('Only setting _basis or existing attributes of _basis is allowed.') """ # allow self._basis = basis - if name == "_basis": + if name == "_basis" or name == "basis": super().__setattr__(name, value) # allow changing existing attributes of self._basis elif hasattr(self._basis, name): diff --git a/tests/test_basis.py b/tests/test_basis.py index 6b7a235a..1389a407 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -297,7 +297,6 @@ def cls(self): pass -# Auto-generated file with stripped classes and shared methods @pytest.mark.parametrize( "cls", [ @@ -1249,6 +1248,7 @@ def test_transform_fails(self, cls): def test_transformer_get_params(self, cls): bas = cls["eval"](n_basis_funcs=5, **extra_decay_rates(cls["eval"], 5)) + bas.set_input_shape(*([1] * bas._n_basis_input)) bas_transformer = bas.to_transformer() params_transf = bas_transformer.get_params() params_transf.pop("_basis") @@ -3592,8 +3592,9 @@ def test_basis_to_transformer(basis_cls, class_specific_params): bas = CombinedBasis().instantiate_basis( n_basis_funcs, basis_cls, class_specific_params, window_size=10 ) - - trans_bas = bas.to_transformer() + trans_bas = bas.set_input_shape( + *([1] * bas._n_input_dimensionality) + ).to_transformer() assert isinstance(trans_bas, basis.TransformerBasis) @@ -3619,13 +3620,18 @@ def test_transformer_has_the_same_public_attributes_as_basis( public_attrs_basis = {attr for attr in dir(bas) if not attr.startswith("_")} public_attrs_transformerbasis = { - attr for attr in dir(bas.to_transformer()) if not attr.startswith("_") + attr + 
for attr in dir( + bas.set_input_shape(*([1] * bas._n_input_dimensionality)).to_transformer() + ) + if not attr.startswith("_") } assert public_attrs_transformerbasis - public_attrs_basis == { "fit", "fit_transform", "transform", + "basis", } assert public_attrs_basis - public_attrs_transformerbasis == set() @@ -3642,7 +3648,7 @@ def test_to_transformer_and_constructor_are_equivalent( bas = CombinedBasis().instantiate_basis( n_basis_funcs, basis_cls, class_specific_params, window_size=10 ) - + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) trans_bas_a = bas.to_transformer() trans_bas_b = basis.TransformerBasis(bas) @@ -3668,7 +3674,9 @@ def test_basis_to_transformer_makes_a_copy(basis_cls, class_specific_params): bas_a = CombinedBasis().instantiate_basis( 5, basis_cls, class_specific_params, window_size=10 ) - trans_bas_a = bas_a.to_transformer() + trans_bas_a = bas_a.set_input_shape( + *([1] * bas_a._n_input_dimensionality) + ).to_transformer() # changing an attribute in bas should not change trans_bas if basis_cls in [AdditiveBasis, MultiplicativeBasis]: @@ -3679,6 +3687,7 @@ def test_basis_to_transformer_makes_a_copy(basis_cls, class_specific_params): bas_b = CombinedBasis().instantiate_basis( 5, basis_cls, class_specific_params, window_size=10 ) + bas_b.set_input_shape(*([1] * bas_b._n_input_dimensionality)) trans_bas_b = bas_b.to_transformer() trans_bas_b._basis._basis1.n_basis_funcs = 100 assert bas_b._basis1.n_basis_funcs == 5 @@ -3690,7 +3699,9 @@ def test_basis_to_transformer_makes_a_copy(basis_cls, class_specific_params): bas_b = CombinedBasis().instantiate_basis( 5, basis_cls, class_specific_params, window_size=10 ) - trans_bas_b = bas_b.to_transformer() + trans_bas_b = bas_b.set_input_shape( + *([1] * bas_b._n_input_dimensionality) + ).to_transformer() trans_bas_b.n_basis_funcs = 100 assert bas_b.n_basis_funcs == 5 @@ -3701,10 +3712,11 @@ def test_basis_to_transformer_makes_a_copy(basis_cls, class_specific_params): ) 
@pytest.mark.parametrize("n_basis_funcs", [5, 10, 20]) def test_transformerbasis_getattr(basis_cls, n_basis_funcs, class_specific_params): + bas = CombinedBasis().instantiate_basis( + n_basis_funcs, basis_cls, class_specific_params, window_size=10 + ) trans_basis = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=10 - ) + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) ) if basis_cls in [AdditiveBasis, MultiplicativeBasis]: for bas in [ @@ -3724,10 +3736,11 @@ def test_transformerbasis_getattr(basis_cls, n_basis_funcs, class_specific_param def test_transformerbasis_set_params( basis_cls, n_basis_funcs_init, n_basis_funcs_new, class_specific_params ): + bas = CombinedBasis().instantiate_basis( + n_basis_funcs_init, basis_cls, class_specific_params, window_size=10 + ) trans_basis = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - n_basis_funcs_init, basis_cls, class_specific_params, window_size=10 - ) + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) ) trans_basis.set_params(n_basis_funcs=n_basis_funcs_new) @@ -3741,15 +3754,19 @@ def test_transformerbasis_set_params( ) def test_transformerbasis_setattr_basis(basis_cls, class_specific_params): # setting the _basis attribute should change it + bas = CombinedBasis().instantiate_basis( + 10, basis_cls, class_specific_params, window_size=10 + ) trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=10 - ) + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) ) - trans_bas._basis = CombinedBasis().instantiate_basis( + + bas = CombinedBasis().instantiate_basis( 20, basis_cls, class_specific_params, window_size=10 ) + trans_bas.basis = bas.set_input_shape(*([1] * bas._n_input_dimensionality)) + assert trans_bas.n_basis_funcs == 20 assert trans_bas._basis.n_basis_funcs == 20 assert isinstance(trans_bas._basis, basis_cls) @@ -3762,10 +3779,11 
@@ def test_transformerbasis_setattr_basis(basis_cls, class_specific_params): def test_transformerbasis_setattr_basis_attribute(basis_cls, class_specific_params): # setting an attribute that is an attribute of the underlying _basis # should propagate setting it on _basis itself + bas = CombinedBasis().instantiate_basis( + 10, basis_cls, class_specific_params, window_size=10 + ) trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=10 - ) + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) ) trans_bas.n_basis_funcs = 20 @@ -3784,6 +3802,7 @@ def test_transformerbasis_copy_basis_on_contsruct(basis_cls, class_specific_para orig_bas = CombinedBasis().instantiate_basis( 10, basis_cls, class_specific_params, window_size=10 ) + orig_bas = orig_bas.set_input_shape(*([1] * orig_bas._n_input_dimensionality)) trans_bas = basis.TransformerBasis(orig_bas) trans_bas.n_basis_funcs = 20 @@ -3800,10 +3819,11 @@ def test_transformerbasis_copy_basis_on_contsruct(basis_cls, class_specific_para def test_transformerbasis_setattr_illegal_attribute(basis_cls, class_specific_params): # changing an attribute that is not _basis or an attribute of _basis # is not allowed + bas = CombinedBasis().instantiate_basis( + 10, basis_cls, class_specific_params, window_size=10 + ) trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=10 - ) + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) ) with pytest.raises( @@ -3823,9 +3843,11 @@ def test_transformerbasis_addition(basis_cls, class_specific_params): bas_a = CombinedBasis().instantiate_basis( n_basis_funcs_a, basis_cls, class_specific_params, window_size=10 ) + bas_a.set_input_shape(*([1] * bas_a._n_input_dimensionality)) bas_b = CombinedBasis().instantiate_basis( n_basis_funcs_b, basis_cls, class_specific_params, window_size=10 ) + bas_b.set_input_shape(*([1] * bas_b._n_input_dimensionality)) 
trans_bas_a = basis.TransformerBasis(bas_a) trans_bas_b = basis.TransformerBasis(bas_b) trans_bas_sum = trans_bas_a + trans_bas_b @@ -3851,15 +3873,17 @@ def test_transformerbasis_addition(basis_cls, class_specific_params): def test_transformerbasis_multiplication(basis_cls, class_specific_params): n_basis_funcs_a = 5 n_basis_funcs_b = n_basis_funcs_a * 2 + bas1 = CombinedBasis().instantiate_basis( + n_basis_funcs_a, basis_cls, class_specific_params, window_size=10 + ) trans_bas_a = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - n_basis_funcs_a, basis_cls, class_specific_params, window_size=10 - ) + bas1.set_input_shape(*([1] * bas1._n_input_dimensionality)) + ) + bas2 = CombinedBasis().instantiate_basis( + n_basis_funcs_b, basis_cls, class_specific_params, window_size=10 ) trans_bas_b = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - n_basis_funcs_b, basis_cls, class_specific_params, window_size=10 - ) + bas2.set_input_shape(*([1] * bas2._n_input_dimensionality)) ) trans_bas_prod = trans_bas_a * trans_bas_b assert isinstance(trans_bas_prod, basis.TransformerBasis) @@ -3893,10 +3917,11 @@ def test_transformerbasis_multiplication(basis_cls, class_specific_params): def test_transformerbasis_exponentiation( basis_cls, exponent: int, error_type, error_message, class_specific_params ): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, class_specific_params, window_size=10 + ) trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 - ) + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) ) if not isinstance(exponent, int): @@ -3911,10 +3936,11 @@ def test_transformerbasis_exponentiation( list_all_basis_classes(), ) def test_transformerbasis_dir(basis_cls, class_specific_params): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, class_specific_params, window_size=10 + ) trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - 5, 
basis_cls, class_specific_params, window_size=10 - ) + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) ) for attr_name in ( "fit", @@ -3940,6 +3966,7 @@ def test_transformerbasis_sk_clone_kernel_noned(basis_cls, class_specific_params orig_bas = CombinedBasis().instantiate_basis( 10, basis_cls, class_specific_params, window_size=20 ) + orig_bas.set_input_shape(*([1] * orig_bas._n_input_dimensionality)) trans_bas = basis.TransformerBasis(orig_bas) # kernel should be saved in the object after fit @@ -3963,11 +3990,12 @@ def test_transformerbasis_sk_clone_kernel_noned(basis_cls, class_specific_params def test_transformerbasis_pickle( tmpdir, basis_cls, n_basis_funcs, class_specific_params ): + bas = CombinedBasis().instantiate_basis( + n_basis_funcs, basis_cls, class_specific_params, window_size=10 + ) # the test that tries cross-validation with n_jobs = 2 already should test this trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=10 - ) + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) ) filepath = tmpdir / "transformerbasis.pickle" with open(filepath, "wb") as f: @@ -4109,13 +4137,13 @@ def test_multi_epoch_pynapple_basis_transformer( n_input = bas._n_input_dimensionality - # pass through transformer - bas = basis.TransformerBasis(bas) - # concat input X = pynapple_concatenate_numpy([tsd[:, None]] * n_input, axis=1) # run convolutions + # pass through transformer + bas.set_input_shape(X) + bas = basis.TransformerBasis(bas) res = bas.fit_transform(X) # check nans From 67476294da6a92f18a480efa8a02ea8e884ef7f0 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 5 Dec 2024 17:59:10 -0500 Subject: [PATCH 04/37] fixed shared tests --- tests/test_basis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index 1389a407..1b38dca0 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -1248,7 +1248,7 @@ 
def test_transform_fails(self, cls): def test_transformer_get_params(self, cls): bas = cls["eval"](n_basis_funcs=5, **extra_decay_rates(cls["eval"], 5)) - bas.set_input_shape(*([1] * bas._n_basis_input)) + bas.set_input_shape( *([1] * bas._n_input_dimensionality)) bas_transformer = bas.to_transformer() params_transf = bas_transformer.get_params() params_transf.pop("_basis") From 909ead1bc7dde378d383c80fbbfca7ddb1b8bdb6 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 6 Dec 2024 12:15:30 -0500 Subject: [PATCH 05/37] added tests for new functionalities --- src/nemos/basis/_basis.py | 80 ++++++---- src/nemos/basis/_basis_mixin.py | 24 ++- src/nemos/basis/_decaying_exponential.py | 4 - src/nemos/basis/_raised_cosine_basis.py | 6 - src/nemos/basis/_spline_basis.py | 17 --- src/nemos/basis/basis.py | 66 ++++++--- tests/test_basis.py | 181 ++++++++++++++++++++++- 7 files changed, 292 insertions(+), 86 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 2305f5ea..ec9a865e 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -111,8 +111,6 @@ class Basis(Base, abc.ABC, BasisTransformerMixin): Parameters ---------- - n_basis_funcs : - The number of basis functions. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. 
@@ -135,27 +133,26 @@ class Basis(Base, abc.ABC, BasisTransformerMixin): def __init__( self, - n_basis_funcs: int, mode: Literal["eval", "conv"] = "eval", label: Optional[str] = None, ) -> None: - self.n_basis_funcs = n_basis_funcs + self._n_basis_funcs = getattr(self, "_n_basis_funcs", None) self._n_input_dimensionality = 0 self._mode = mode - self._n_basis_input = None - - # these parameters are going to be set at the first call of `compute_features` - # since we cannot know a-priori how many features may be convolved - self._n_output_features = None - self._input_shape = None - if label is None: self._label = self.__class__.__name__ else: self._label = str(label) + self._check_n_basis_min() + + # specified only after inputs/input shapes are provided + self._n_basis_input = None + self._input_shape = None + + # set by set_kernel self.kernel_ = None @property @@ -169,7 +166,9 @@ def n_output_features(self) -> int | None: provided to the basis is known. Therefore, before the first call to ``compute_features``, this property will return ``None``. After that call, ``n_output_features`` will be available. """ - return self._n_output_features + if self._n_basis_input is not None: + return self.n_basis_funcs * self._n_basis_input[0] + return None @property def label(self) -> str: @@ -633,9 +632,7 @@ def _get_default_slicing( # should we remove this option? 
if self._n_basis_input[0] == 1 or isinstance(self, MultiplicativeBasis): split_dict = { - self.label: slice( - start_slice, start_slice + self._n_output_features - ) + self.label: slice(start_slice, start_slice + self.n_output_features) } else: split_dict = { @@ -649,9 +646,9 @@ def _get_default_slicing( } else: split_dict = { - self.label: slice(start_slice, start_slice + self._n_output_features) + self.label: slice(start_slice, start_slice + self.n_output_features) } - start_slice += self._n_output_features + start_slice += self.n_output_features return split_dict, start_slice def split_by_feature( @@ -817,7 +814,6 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): self._input_shape = shape self._n_basis_input = n_inputs - self._n_output_features = self.n_basis_funcs * self._n_basis_input[0] return self def _list_components(self): @@ -879,15 +875,32 @@ class AdditiveBasis(CompositeBasisMixin, Basis): """ def __init__(self, basis1: Basis, basis2: Basis) -> None: - self.n_basis_funcs = basis1.n_basis_funcs + basis2.n_basis_funcs - Basis.__init__(self, self.n_basis_funcs, mode="eval") - self._label = "(" + basis1.label + " + " + basis2.label + ")" CompositeBasisMixin.__init__(self, basis1, basis2) + Basis.__init__(self, mode="eval") + self._label = "(" + basis1.label + " + " + basis2.label + ")" self._n_input_dimensionality = ( basis1._n_input_dimensionality + basis2._n_input_dimensionality ) + @property + def n_basis_funcs(self): + """Compute the n-basis function runtime. + + This plays well with cross-validation where the number of basis function of the + underlying bases can be changed. It must be read-only since the number of basis + is determined by the two basis elements and the type of composition. 
+ """ + return self.basis1.n_basis_funcs + self.basis2.n_basis_funcs + + @property + def n_output_features(self): + out1 = getattr(self._basis1, "n_output_features", None) + out2 = getattr(self._basis2, "n_output_features", None) + if out1 is None or out2 is None: + return None + return out1 + out2 + def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: """ Set the expected input shape for the basis object. @@ -945,9 +958,6 @@ def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: *xi[self._basis1._n_input_dimensionality :] )._n_basis_input, ) - self._n_output_features = ( - self._basis1.n_output_features + self._basis2.n_output_features - ) return self @support_pynapple(conv_type="numpy") @@ -1245,15 +1255,28 @@ class MultiplicativeBasis(CompositeBasisMixin, Basis): """ def __init__(self, basis1: Basis, basis2: Basis) -> None: - self.n_basis_funcs = basis1.n_basis_funcs * basis2.n_basis_funcs - Basis.__init__(self, self.n_basis_funcs, mode="eval") - self._label = "(" + basis1.label + " * " + basis2.label + ")" CompositeBasisMixin.__init__(self, basis1, basis2) + Basis.__init__(self, mode="eval") + self._label = "(" + basis1.label + " * " + basis2.label + ")" self._n_input_dimensionality = ( basis1._n_input_dimensionality + basis2._n_input_dimensionality ) BasisTransformerMixin.__init__(self) + @property + def n_basis_funcs(self): + """Compute the n-basis function runtime. + + This plays well with cross-validation where the number of basis function of the + underlying bases can be changed. It must be read-only since the number of basis + is determined by the two basis elements and the type of composition. + """ + return self.basis1.n_basis_funcs * self.basis2.n_basis_funcs + + @property + def n_output_features(self): + return self._basis1.n_output_features * self._basis2.n_output_features + def set_kernel(self, *xi: NDArray) -> Basis: """Call fit on the multiplied basis. 
@@ -1395,9 +1418,6 @@ def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: *xi[self._basis1._n_input_dimensionality :] )._n_basis_input, ) - self._n_output_features = ( - self._basis1.n_output_features * self._basis2.n_output_features - ) return self def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index d5ae69c6..dad43593 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -2,6 +2,7 @@ from __future__ import annotations +import abc import copy import inspect import warnings @@ -20,8 +21,11 @@ class EvalBasisMixin: """Mixin class for evaluational basis.""" - def __init__(self, bounds: Optional[Tuple[float, float]] = None): + def __init__( + self, n_basis_funcs: int, bounds: Optional[Tuple[float, float]] = None + ): self.bounds = bounds + self._n_basis_funcs = n_basis_funcs def _compute_features(self, *xi: NDArray): """ @@ -89,9 +93,12 @@ def bounds(self, values: Union[None, Tuple[float, float]]): class ConvBasisMixin: """Mixin class for convolutional basis.""" - def __init__(self, window_size: int, conv_kwargs: Optional[dict] = None): + def __init__( + self, n_basis_funcs: int, window_size: int, conv_kwargs: Optional[dict] = None + ): self.window_size = window_size self.conv_kwargs = {} if conv_kwargs is None else conv_kwargs + self._n_basis_funcs = n_basis_funcs def _compute_features(self, *xi: NDArray): """ @@ -272,14 +279,17 @@ class CompositeBasisMixin: """ def __init__(self, basis1: Basis, basis2: Basis): - self.basis1 = basis1 - self.basis2 = basis2 + # deep copy to avoid changes directly to the 1d basis to be reflected + # in the composite basis. 
+ self.basis1 = copy.deepcopy(basis1) + self.basis2 = copy.deepcopy(basis2) shapes = ( *(bas1._input_shape for bas1 in basis1._list_components()), *(bas2._input_shape for bas2 in basis2._list_components()), ) # if all bases where set, then set input for composition. set_bases = (s is not None for s in shapes) + if all(set_bases): # pass down the input shapes self.set_input_shape(*shapes) @@ -290,6 +300,12 @@ def __init__(self, basis1: Basis, basis2: Basis): category=UserWarning, ) + @property + @abc.abstractmethod + def n_basis_funcs(self): + """Read only property for composite bases.""" + pass + def _check_n_basis_min(self) -> None: pass diff --git a/src/nemos/basis/_decaying_exponential.py b/src/nemos/basis/_decaying_exponential.py index 679c9d64..ea7ea711 100644 --- a/src/nemos/basis/_decaying_exponential.py +++ b/src/nemos/basis/_decaying_exponential.py @@ -20,8 +20,6 @@ class OrthExponentialBasis(Basis, abc.ABC): Parameters ---------- - n_basis_funcs - Number of basis functions. decay_rates : Decay rates of the exponentials, shape ``(n_basis_funcs,)``. mode : @@ -34,13 +32,11 @@ class OrthExponentialBasis(Basis, abc.ABC): def __init__( self, - n_basis_funcs: int, decay_rates: NDArray[np.floating], mode="eval", label: Optional[str] = "OrthExponentialBasis", ): super().__init__( - n_basis_funcs, mode=mode, label=label, ) diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index 3b70b2ff..145b260c 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -20,8 +20,6 @@ class RaisedCosineBasisLinear(Basis, abc.ABC): Parameters ---------- - n_basis_funcs : - The number of basis functions. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. 
@@ -41,13 +39,11 @@ class RaisedCosineBasisLinear(Basis, abc.ABC): def __init__( self, - n_basis_funcs: int, mode="eval", width: float = 2.0, label: Optional[str] = "RaisedCosineBasisLinear", ) -> None: super().__init__( - n_basis_funcs, mode=mode, label=label, ) @@ -233,7 +229,6 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear, abc.ABC): def __init__( self, - n_basis_funcs: int, mode="eval", width: float = 2.0, time_scaling: float = None, @@ -241,7 +236,6 @@ def __init__( label: Optional[str] = "RaisedCosineBasisLog", ) -> None: super().__init__( - n_basis_funcs, mode=mode, width=width, label=label, diff --git a/src/nemos/basis/_spline_basis.py b/src/nemos/basis/_spline_basis.py index 3d4bcb22..6060ce5b 100644 --- a/src/nemos/basis/_spline_basis.py +++ b/src/nemos/basis/_spline_basis.py @@ -20,8 +20,6 @@ class SplineBasis(Basis, abc.ABC): Parameters ---------- - n_basis_funcs : - Number of basis functions. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. @@ -39,14 +37,12 @@ class SplineBasis(Basis, abc.ABC): def __init__( self, - n_basis_funcs: int, order: int = 2, label: Optional[str] = None, mode: Literal["conv", "eval"] = "eval", ) -> None: self.order = order super().__init__( - n_basis_funcs, label=label, mode=mode, ) @@ -157,9 +153,6 @@ class MSplineBasis(SplineBasis, abc.ABC): Parameters ---------- - n_basis_funcs : - The number of basis functions to generate. More basis functions allow for - more flexible data modeling but can lead to overfitting. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. 
@@ -197,13 +190,11 @@ class MSplineBasis(SplineBasis, abc.ABC): def __init__( self, - n_basis_funcs: int, mode: Literal["eval", "conv"] = "eval", order: int = 2, label: Optional[str] = "MSplineEval", ) -> None: super().__init__( - n_basis_funcs, mode=mode, order=order, label=label, @@ -298,8 +289,6 @@ class BSplineBasis(SplineBasis, abc.ABC): Parameters ---------- - n_basis_funcs : - Number of basis functions. mode : The mode of operation. ``'eval'`` for evaluation at sample points, 'conv' for convolutional operation. @@ -325,13 +314,11 @@ class BSplineBasis(SplineBasis, abc.ABC): def __init__( self, - n_basis_funcs: int, mode="eval", order: int = 4, label: Optional[str] = "BSplineBasis", ): super().__init__( - n_basis_funcs, mode=mode, order=order, label=label, @@ -414,8 +401,6 @@ class CyclicBSplineBasis(SplineBasis, abc.ABC): Parameters ---------- - n_basis_funcs : - Number of basis functions. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. 
@@ -437,13 +422,11 @@ class CyclicBSplineBasis(SplineBasis, abc.ABC): def __init__( self, - n_basis_funcs: int, mode="eval", order: int = 4, label: Optional[str] = "CyclicBSplineBasis", ): super().__init__( - n_basis_funcs, mode=mode, order=order, label=label, diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 88a57901..720f5811 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -83,10 +83,9 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "BSplineEval", ): - EvalBasisMixin.__init__(self, bounds=bounds) + EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) BSplineBasis.__init__( self, - n_basis_funcs, mode="eval", order=order, label=label, @@ -231,10 +230,14 @@ def __init__( label: Optional[str] = "BSplineConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) + ConvBasisMixin.__init__( + self, + n_basis_funcs=n_basis_funcs, + window_size=window_size, + conv_kwargs=conv_kwargs, + ) BSplineBasis.__init__( self, - n_basis_funcs, mode="conv", order=order, label=label, @@ -372,10 +375,9 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "CyclicBSplineEval", ): - EvalBasisMixin.__init__(self, bounds=bounds) + EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) CyclicBSplineBasis.__init__( self, - n_basis_funcs, mode="eval", order=order, label=label, @@ -512,10 +514,14 @@ def __init__( label: Optional[str] = "CyclicBSplineConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) + ConvBasisMixin.__init__( + self, + n_basis_funcs=n_basis_funcs, + window_size=window_size, + conv_kwargs=conv_kwargs, + ) CyclicBSplineBasis.__init__( self, - n_basis_funcs, mode="conv", order=order, label=label, @@ -677,10 +683,9 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: 
Optional[str] = "MSplineEval", ): - EvalBasisMixin.__init__(self, bounds=bounds) + EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) MSplineBasis.__init__( self, - n_basis_funcs, mode="eval", order=order, label=label, @@ -841,10 +846,14 @@ def __init__( label: Optional[str] = "MSplineConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) + ConvBasisMixin.__init__( + self, + n_basis_funcs=n_basis_funcs, + window_size=window_size, + conv_kwargs=conv_kwargs, + ) MSplineBasis.__init__( self, - n_basis_funcs, mode="conv", order=order, label=label, @@ -992,10 +1001,9 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "RaisedCosineLinearEval", ): - EvalBasisMixin.__init__(self, bounds=bounds) + EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) RaisedCosineBasisLinear.__init__( self, - n_basis_funcs, width=width, mode="eval", label=label, @@ -1135,10 +1143,14 @@ def __init__( label: Optional[str] = "RaisedCosineLinearConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) + ConvBasisMixin.__init__( + self, + n_basis_funcs=n_basis_funcs, + window_size=window_size, + conv_kwargs=conv_kwargs, + ) RaisedCosineBasisLinear.__init__( self, - n_basis_funcs, mode="conv", width=width, label=label, @@ -1285,10 +1297,9 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "RaisedCosineLogEval", ): - EvalBasisMixin.__init__(self, bounds=bounds) + EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) RaisedCosineBasisLog.__init__( self, - n_basis_funcs, width=width, time_scaling=time_scaling, enforce_decay_to_zero=enforce_decay_to_zero, @@ -1438,10 +1449,14 @@ def __init__( label: Optional[str] = "RaisedCosineLogConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, 
window_size=window_size, conv_kwargs=conv_kwargs) + ConvBasisMixin.__init__( + self, + n_basis_funcs=n_basis_funcs, + window_size=window_size, + conv_kwargs=conv_kwargs, + ) RaisedCosineBasisLog.__init__( self, - n_basis_funcs, mode="conv", width=width, time_scaling=time_scaling, @@ -1576,10 +1591,9 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "OrthExponentialEval", ): - EvalBasisMixin.__init__(self, bounds=bounds) + EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) OrthExponentialBasis.__init__( self, - n_basis_funcs, decay_rates=decay_rates, mode="eval", label=label, @@ -1713,10 +1727,14 @@ def __init__( label: Optional[str] = "OrthExponentialConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) + ConvBasisMixin.__init__( + self, + n_basis_funcs=n_basis_funcs, + window_size=window_size, + conv_kwargs=conv_kwargs, + ) OrthExponentialBasis.__init__( self, - n_basis_funcs, mode="conv", decay_rates=decay_rates, label=label, diff --git a/tests/test_basis.py b/tests/test_basis.py index 1b38dca0..593e8152 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -1248,7 +1248,7 @@ def test_transform_fails(self, cls): def test_transformer_get_params(self, cls): bas = cls["eval"](n_basis_funcs=5, **extra_decay_rates(cls["eval"], 5)) - bas.set_input_shape( *([1] * bas._n_input_dimensionality)) + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) bas_transformer = bas.to_transformer() params_transf = bas_transformer.get_params() params_transf.pop("_basis") @@ -2730,6 +2730,86 @@ def test_set_input_value_types( with expectation: add.set_input_shape(*inp_shape) + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + def test_deep_copy_basis(self, basis_a, basis_b, 
class_specific_params): + basis_a = self.instantiate_basis( + 5, basis_a, class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, class_specific_params, window_size=10 + ) + add = basis_a + basis_b + # test pointing to different objects + assert id(add.basis1) != id(basis_a) + assert id(add.basis1) != id(basis_b) + assert id(add.basis2) != id(basis_a) + assert id(add.basis2) != id(basis_b) + + # test attributes are not related + basis_a.n_basis_funcs = 10 + basis_b.n_basis_funcs = 10 + assert add.basis1.n_basis_funcs == 5 + assert add.basis2.n_basis_funcs == 5 + + add.basis1.n_basis_funcs = 6 + add.basis2.n_basis_funcs = 6 + assert basis_a.n_basis_funcs == 10 + assert basis_b.n_basis_funcs == 10 + + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + def test_compute_n_basis_runtime(self, basis_a, basis_b, class_specific_params): + basis_a = self.instantiate_basis( + 5, basis_a, class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, class_specific_params, window_size=10 + ) + add = basis_a + basis_b + add.basis1.n_basis_funcs = 10 + assert add.n_basis_funcs == 15 + add.basis2.n_basis_funcs = 10 + assert add.n_basis_funcs == 20 + + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes() + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes() + ) + def test_runtime_n_basis_out_compute(self, basis_a, basis_b, class_specific_params): + basis_a = self.instantiate_basis( + 5, basis_a, class_specific_params, window_size=10 + ) + basis_a.set_input_shape(*([1] * basis_a._n_input_dimensionality)).to_transformer() + basis_b = self.instantiate_basis( + 5, basis_b, class_specific_params, window_size=10 + ) + basis_b.set_input_shape(*([1] * basis_b._n_input_dimensionality)).to_transformer() + add = basis_a + basis_b + 
inps_a = [2] * basis_a._n_input_dimensionality + add.basis1.set_input_shape(*inps_a) + if isinstance(basis_a, MultiplicativeBasis): + new_out_num = np.prod(inps_a) * add.basis1.n_basis_funcs + else: + new_out_num = inps_a[0] * add.basis1.n_basis_funcs + assert add.n_output_features == new_out_num + add.basis2.n_basis_funcs + inps_b = [3] * basis_b._n_input_dimensionality + if isinstance(basis_b, MultiplicativeBasis): + new_out_num_b = np.prod(inps_b) * add.basis2.n_basis_funcs + else: + new_out_num_b = inps_b[0] * add.basis2.n_basis_funcs + add.basis2.set_input_shape(*inps_b) + assert add.n_output_features == new_out_num + new_out_num_b class TestMultiplicativeBasis(CombinedBasis): cls = {"eval": MultiplicativeBasis, "conv": MultiplicativeBasis} @@ -3536,6 +3616,87 @@ def test_set_input_value_types( with expectation: mul.set_input_shape(*inp_shape) + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + def test_deep_copy_basis(self, basis_a, basis_b, class_specific_params): + basis_a = self.instantiate_basis( + 5, basis_a, class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, class_specific_params, window_size=10 + ) + mul = basis_a * basis_b + # test pointing to different objects + assert id(mul.basis1) != id(basis_a) + assert id(mul.basis1) != id(basis_b) + assert id(mul.basis2) != id(basis_a) + assert id(mul.basis2) != id(basis_b) + + # test attributes are not related + basis_a.n_basis_funcs = 10 + basis_b.n_basis_funcs = 10 + assert mul.basis1.n_basis_funcs == 5 + assert mul.basis2.n_basis_funcs == 5 + + mul.basis1.n_basis_funcs = 6 + mul.basis2.n_basis_funcs = 6 + assert basis_a.n_basis_funcs == 10 + assert basis_b.n_basis_funcs == 10 + + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + 
@pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + def test_compute_n_basis_runtime(self, basis_a, basis_b, class_specific_params): + basis_a = self.instantiate_basis( + 5, basis_a, class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, class_specific_params, window_size=10 + ) + mul = basis_a * basis_b + mul.basis1.n_basis_funcs = 10 + assert mul.n_basis_funcs == 50 + mul.basis2.n_basis_funcs = 10 + assert mul.n_basis_funcs == 100 + + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes() + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes() + ) + def test_runtime_n_basis_out_compute(self, basis_a, basis_b, class_specific_params): + basis_a = self.instantiate_basis( + 5, basis_a, class_specific_params, window_size=10 + ) + basis_a.set_input_shape(*([1] * basis_a._n_input_dimensionality)).to_transformer() + basis_b = self.instantiate_basis( + 5, basis_b, class_specific_params, window_size=10 + ) + basis_b.set_input_shape(*([1] * basis_b._n_input_dimensionality)).to_transformer() + mul = basis_a * basis_b + inps_a = [2] * basis_a._n_input_dimensionality + mul.basis1.set_input_shape(*inps_a) + if isinstance(basis_a, MultiplicativeBasis): + new_out_num = np.prod(inps_a) * mul.basis1.n_basis_funcs + else: + new_out_num = inps_a[0] * mul.basis1.n_basis_funcs + assert mul.n_output_features == new_out_num * mul.basis2.n_basis_funcs + inps_b = [3] * basis_b._n_input_dimensionality + if isinstance(basis_b, MultiplicativeBasis): + new_out_num_b = np.prod(inps_b) * mul.basis2.n_basis_funcs + else: + new_out_num_b = inps_b[0] * mul.basis2.n_basis_funcs + mul.basis2.set_input_shape(*inps_b) + assert mul.n_output_features == new_out_num * new_out_num_b + @pytest.mark.parametrize( "exponent", [-1, 0, 0.5, basis.RaisedCosineLogEval(4), 1, 2, 3] @@ -3583,6 +3744,8 @@ def test_power_of_basis(exponent, basis_class, class_specific_params): assert 
np.all(np.isnan(out[~non_nan])) + + @pytest.mark.parametrize( "basis_cls", list_all_basis_classes(), @@ -4246,12 +4409,21 @@ def test__get_splitter( bas1_instance = combine_basis.instantiate_basis( n_basis[0], bas1, class_specific_params, window_size=10, label="1" ) + bas1_instance.set_input_shape( + *([n_input_basis[0]] * bas1_instance._n_input_dimensionality) + ) bas2_instance = combine_basis.instantiate_basis( n_basis[1], bas2, class_specific_params, window_size=10, label="2" ) + bas2_instance.set_input_shape( + *([n_input_basis[1]] * bas2_instance._n_input_dimensionality) + ) bas3_instance = combine_basis.instantiate_basis( n_basis[2], bas3, class_specific_params, window_size=10, label="3" ) + bas3_instance.set_input_shape( + *([n_input_basis[2]] * bas3_instance._n_input_dimensionality) + ) func1 = getattr(bas1_instance, operator1) func2 = getattr(bas2_instance, operator2) @@ -4401,9 +4573,16 @@ def test__get_splitter_split_by_input( bas1_instance = combine_basis.instantiate_basis( n_basis[0], bas1, class_specific_params, window_size=10, label="1" ) + bas1_instance.set_input_shape( + *([n_input_basis_1] * bas1_instance._n_input_dimensionality) + ) + bas2_instance = combine_basis.instantiate_basis( n_basis[1], bas2, class_specific_params, window_size=10, label="2" ) + bas2_instance.set_input_shape( + *([n_input_basis_2] * bas1_instance._n_input_dimensionality) + ) func1 = getattr(bas1_instance, operator) bas12 = func1(bas2_instance) From f31999f9c63aeaf37f7077bbf7299583fb72648a Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 6 Dec 2024 12:46:38 -0500 Subject: [PATCH 06/37] broken Transformer --- src/nemos/basis/_basis.py | 53 +++++++++---------- src/nemos/basis/_basis_mixin.py | 2 +- src/nemos/basis/_transformer_basis.py | 6 +-- tests/test_basis.py | 76 +++++++++++++-------------- 4 files changed, 68 insertions(+), 69 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index ec9a865e..c5bf1aa5 100644 --- 
a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -149,7 +149,7 @@ def __init__( self._check_n_basis_min() # specified only after inputs/input shapes are provided - self._n_basis_input = None + self._n_basis_input_ = None self._input_shape = None # set by set_kernel @@ -166,8 +166,8 @@ def n_output_features(self) -> int | None: provided to the basis is known. Therefore, before the first call to ``compute_features``, this property will return ``None``. After that call, ``n_output_features`` will be available. """ - if self._n_basis_input is not None: - return self.n_basis_funcs * self._n_basis_input[0] + if self._n_basis_input_ is not None: + return self.n_basis_funcs * self._n_basis_input_[0] return None @property @@ -176,12 +176,12 @@ def label(self) -> str: return self._label @property - def n_basis_input(self) -> tuple | None: + def n_basis_input_(self) -> tuple | None: """Number of expected inputs. The number of inputs ``compute_feature`` expects. """ - return self._n_basis_input + return self._n_basis_input_ @property def n_basis_funcs(self): @@ -270,7 +270,7 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: Subclasses should implement how to handle the transformation specific to their basis function types and operation modes. """ - if self._n_basis_input is None: + if self._n_basis_input_ is None: self.set_input_shape(*xi) self._check_input_shape_consistency(*xi) self.set_kernel() @@ -558,7 +558,7 @@ def _get_feature_slicing( Parameters ---------- n_inputs : - The number of input basis for each component, by default it uses ``self._n_basis_input``. + The number of input basis for each component, by default it uses ``self._n_basis_input_``. start_slice : The starting index for slicing, by default it starts from 0. split_by_input : @@ -580,18 +580,18 @@ def _get_feature_slicing( _merge_slicing_dicts : Merges multiple slicing dictionaries, handling keys conflicts. 
""" # Set default values for n_inputs and start_slice if not provided - n_inputs = n_inputs or self._n_basis_input + n_inputs = n_inputs or self._n_basis_input_ start_slice = start_slice or 0 # If the instance is of AdditiveBasis type, handle slicing for the additive components if isinstance(self, AdditiveBasis): split_dict, start_slice = self._basis1._get_feature_slicing( - n_inputs[: len(self._basis1._n_basis_input)], + n_inputs[: len(self._basis1._n_basis_input_)], start_slice, split_by_input=split_by_input, ) sp2, start_slice = self._basis2._get_feature_slicing( - n_inputs[len(self._basis1._n_basis_input) :], + n_inputs[len(self._basis1._n_basis_input_):], start_slice, split_by_input=split_by_input, ) @@ -630,7 +630,7 @@ def _get_default_slicing( """Handle default slicing logic.""" if split_by_input: # should we remove this option? - if self._n_basis_input[0] == 1 or isinstance(self, MultiplicativeBasis): + if self._n_basis_input_[0] == 1 or isinstance(self, MultiplicativeBasis): split_dict = { self.label: slice(start_slice, start_slice + self.n_output_features) } @@ -641,7 +641,7 @@ def _get_default_slicing( start_slice + i * self.n_basis_funcs, start_slice + (i + 1) * self.n_basis_funcs, ) - for i in range(self._n_basis_input[0]) + for i in range(self._n_basis_input_[0]) } } else: @@ -739,13 +739,13 @@ def is_leaf(val): # Apply the slicing using the custom leaf function out = jax.tree_util.tree_map(lambda sl: x[sl], index_dict, is_leaf=is_leaf) - # reshape the arrays to spilt by n_basis_input + # reshape the arrays to spilt by n_basis_input_ reshaped_out = dict() for i, vals in enumerate(out.items()): key, val = vals shape = list(val.shape) reshaped_out[key] = val.reshape( - shape[:axis] + [self._n_basis_input[i], -1] + shape[axis + 1 :] + shape[:axis] + [self._n_basis_input_[i], -1] + shape[axis + 1:] ) return reshaped_out @@ -813,7 +813,7 @@ def set_input_shape(self, xi: int | tuple[int, ...] 
| NDArray): self._input_shape = shape - self._n_basis_input = n_inputs + self._n_basis_input_ = n_inputs return self def _list_components(self): @@ -950,13 +950,13 @@ def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: 181 """ - self._n_basis_input = ( + self._n_basis_input_ = ( *self._basis1.set_input_shape( *xi[: self._basis1._n_input_dimensionality] - )._n_basis_input, + )._n_basis_input_, *self._basis2.set_input_shape( *xi[self._basis1._n_input_dimensionality :] - )._n_basis_input, + )._n_basis_input_, ) return self @@ -1231,11 +1231,6 @@ class MultiplicativeBasis(CompositeBasisMixin, Basis): basis2 : Second basis object to multiply. - Attributes - ---------- - n_basis_funcs : - Number of basis functions. - Examples -------- >>> # Generate sample data @@ -1275,7 +1270,11 @@ def n_basis_funcs(self): @property def n_output_features(self): - return self._basis1.n_output_features * self._basis2.n_output_features + out1 = getattr(self._basis1, "n_output_features", None) + out2 = getattr(self._basis2, "n_output_features", None) + if out1 is None or out2 is None: + return None + return out1 * out2 def set_kernel(self, *xi: NDArray) -> Basis: """Call fit on the multiplied basis. @@ -1410,13 +1409,13 @@ def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: 25200 """ - self._n_basis_input = ( + self._n_basis_input_ = ( *self._basis1.set_input_shape( *xi[: self._basis1._n_input_dimensionality] - )._n_basis_input, + )._n_basis_input_, *self._basis2.set_input_shape( *xi[self._basis1._n_input_dimensionality :] - )._n_basis_input, + )._n_basis_input_, ) return self diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index dad43593..52031c98 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -268,7 +268,7 @@ def to_transformer(self) -> TransformerBasis: ... 
) >>> gridsearch = gridsearch.fit(X, y) """ - return TransformerBasis(copy.deepcopy(self)) + return TransformerBasis(self) class CompositeBasisMixin: diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 108fee6c..e9a566fd 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -62,11 +62,11 @@ class TransformerBasis: """ def __init__(self, basis: Basis): - self.basis = copy.deepcopy(basis) + self._basis = copy.deepcopy(basis) @staticmethod def _check_initialized(basis): - if basis._n_basis_input is None: + if basis._n_basis_input_ is None: raise RuntimeError( "TransformerBasis initialization failed: the provided basis has no defined input shape. " "Please call `set_input_shape` on the basis before initializing the transformer." @@ -102,7 +102,7 @@ def _unpack_inputs(self, X: FeatureMatrix) -> List: out = [] cc = 0 for i, bas in enumerate(self._basis._list_components()): - n_input = self._n_basis_input[i] + n_input = self._n_basis_input_[i] out.append( np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape)) ) diff --git a/tests/test_basis.py b/tests/test_basis.py index 593e8152..8a4d7965 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -347,7 +347,7 @@ def test_call_vmin_vmax(self, samples, vmin, vmax, expectation, cls): [ ("label", None), ("label", "label"), - ("n_basis_input", 1), + ("n_basis_input_", 1), ("n_output_features", 5), ], ) @@ -441,10 +441,10 @@ def test_set_num_basis_input(self, n_input, cls): bas = cls["conv"]( n_basis_funcs=5, window_size=10, **extra_decay_rates(cls["conv"], 5) ) - assert bas.n_basis_input is None + assert bas.n_basis_input_ is None bas.compute_features(np.random.randn(20, n_input)) - assert bas.n_basis_input == (n_input,) - assert bas._n_basis_input == (n_input,) + assert bas.n_basis_input_ == (n_input,) + assert bas._n_basis_input_ == (n_input,) @pytest.mark.parametrize( "bounds, samples, nan_idx, mn, mx", @@ -2552,11 
+2552,11 @@ def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): bas1 = basis.RaisedCosineLinearConv(10, window_size=10) bas2 = basis.BSplineConv(10, window_size=10) bas_add = bas1 + bas2 - assert bas_add.n_basis_input is None + assert bas_add.n_basis_input_ is None bas_add.compute_features( np.ones((20, n_basis_input1)), np.ones((20, n_basis_input2)) ) - assert bas_add.n_basis_input == (n_basis_input1, n_basis_input2) + assert bas_add.n_basis_input_ == (n_basis_input1, n_basis_input2) @pytest.mark.parametrize( "n_input, expectation", @@ -3427,11 +3427,11 @@ def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): bas1 = basis.RaisedCosineLinearConv(10, window_size=10) bas2 = basis.BSplineConv(10, window_size=10) bas_add = bas1 * bas2 - assert bas_add.n_basis_input is None + assert bas_add.n_basis_input_ is None bas_add.compute_features( np.ones((20, n_basis_input1)), np.ones((20, n_basis_input2)) ) - assert bas_add.n_basis_input == (n_basis_input1, n_basis_input2) + assert bas_add.n_basis_input_ == (n_basis_input1, n_basis_input2) @pytest.mark.parametrize( "n_input, expectation", @@ -3453,14 +3453,14 @@ def test_expected_input_number(self, n_input, expectation): @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) - def test_n_basis_input(self, n_basis_input1, n_basis_input2): + def test_n_basis_input_(self, n_basis_input1, n_basis_input2): bas1 = basis.RaisedCosineLinearConv(10, window_size=10) bas2 = basis.BSplineConv(10, window_size=10) bas_prod = bas1 * bas2 bas_prod.compute_features( np.ones((20, n_basis_input1)), np.ones((20, n_basis_input2)) ) - assert bas_prod.n_basis_input == (n_basis_input1, n_basis_input2) + assert bas_prod.n_basis_input_ == (n_basis_input1, n_basis_input2) @pytest.mark.parametrize( "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") @@ -4329,18 +4329,18 @@ def test_multi_epoch_pynapple_basis_transformer( "__add__", "__add__", lambda 
bas1, bas2, bas3: { - "1": slice(0, bas1._n_basis_input[0] * bas1.n_basis_funcs), + "1": slice(0, bas1._n_basis_input_[0] * bas1.n_basis_funcs), "2": slice( - bas1._n_basis_input[0] * bas1.n_basis_funcs, - bas1._n_basis_input[0] * bas1.n_basis_funcs - + bas2._n_basis_input[0] * bas2.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs + + bas2._n_basis_input_[0] * bas2.n_basis_funcs, ), "3": slice( - bas1._n_basis_input[0] * bas1.n_basis_funcs - + bas2._n_basis_input[0] * bas2.n_basis_funcs, - bas1._n_basis_input[0] * bas1.n_basis_funcs - + bas2._n_basis_input[0] * bas2.n_basis_funcs - + bas3._n_basis_input[0] * bas3.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs + + bas2._n_basis_input_[0] * bas2.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs + + bas2._n_basis_input_[0] * bas2.n_basis_funcs + + bas3._n_basis_input_[0] * bas3.n_basis_funcs, ), }, ), @@ -4348,13 +4348,13 @@ def test_multi_epoch_pynapple_basis_transformer( "__add__", "__mul__", lambda bas1, bas2, bas3: { - "1": slice(0, bas1._n_basis_input[0] * bas1.n_basis_funcs), + "1": slice(0, bas1._n_basis_input_[0] * bas1.n_basis_funcs), "(2 * 3)": slice( - bas1._n_basis_input[0] * bas1.n_basis_funcs, - bas1._n_basis_input[0] * bas1.n_basis_funcs - + bas2._n_basis_input[0] + bas1._n_basis_input_[0] * bas1.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs + + bas2._n_basis_input_[0] * bas2.n_basis_funcs - * bas3._n_basis_input[0] + * bas3._n_basis_input_[0] * bas3.n_basis_funcs, ), }, @@ -4366,11 +4366,11 @@ def test_multi_epoch_pynapple_basis_transformer( # note that it doesn't respect algebra order but execute right to left (first add then multiplies) "(1 * (2 + 3))": slice( 0, - bas1._n_basis_input[0] + bas1._n_basis_input_[0] * bas1.n_basis_funcs * ( - bas2._n_basis_input[0] * bas2.n_basis_funcs - + bas3._n_basis_input[0] * bas3.n_basis_funcs + bas2._n_basis_input_[0] * bas2.n_basis_funcs + + 
bas3._n_basis_input_[0] * bas3.n_basis_funcs ), ), }, @@ -4381,11 +4381,11 @@ def test_multi_epoch_pynapple_basis_transformer( lambda bas1, bas2, bas3: { "(1 * (2 * 3))": slice( 0, - bas1._n_basis_input[0] + bas1._n_basis_input_[0] * bas1.n_basis_funcs - * bas2._n_basis_input[0] + * bas2._n_basis_input_[0] * bas2.n_basis_funcs - * bas3._n_basis_input[0] + * bas3._n_basis_input_[0] * bas3.n_basis_funcs, ), }, @@ -4448,11 +4448,11 @@ def test__get_splitter( 1, 1, lambda bas1, bas2: { - "1": slice(0, bas1._n_basis_input[0] * bas1.n_basis_funcs), + "1": slice(0, bas1._n_basis_input_[0] * bas1.n_basis_funcs), "2": slice( - bas1._n_basis_input[0] * bas1.n_basis_funcs, - bas1._n_basis_input[0] * bas1.n_basis_funcs - + bas2._n_basis_input[0] * bas2.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs + + bas2._n_basis_input_[0] * bas2.n_basis_funcs, ), }, ), @@ -4463,9 +4463,9 @@ def test__get_splitter( lambda bas1, bas2: { "(1 * 2)": slice( 0, - bas1._n_basis_input[0] + bas1._n_basis_input_[0] * bas1.n_basis_funcs - * bas2._n_basis_input[0] + * bas2._n_basis_input_[0] * bas2.n_basis_funcs, ) }, @@ -4490,7 +4490,7 @@ def test__get_splitter( 1, lambda bas1, bas2: { "(1 * 2)": slice( - 0, bas1._n_basis_input[0] * bas1.n_basis_funcs * bas2.n_basis_funcs + 0, bas1._n_basis_input_[0] * bas1.n_basis_funcs * bas2.n_basis_funcs ) }, ), @@ -4517,7 +4517,7 @@ def test__get_splitter( 2, lambda bas1, bas2: { "(1 * 2)": slice( - 0, bas2._n_basis_input[0] * bas1.n_basis_funcs * bas2.n_basis_funcs + 0, bas2._n_basis_input_[0] * bas1.n_basis_funcs * bas2.n_basis_funcs ) }, ), From 7c7ccdce2352bf34993d7eb6a8e4ad546d368deb Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 6 Dec 2024 12:56:10 -0500 Subject: [PATCH 07/37] modified all attrs that gets set by basis method to match the sklearn naming convention --- src/nemos/basis/_basis.py | 10 +++++----- src/nemos/basis/_basis_mixin.py | 4 ++-- 
src/nemos/basis/_transformer_basis.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index c5bf1aa5..7d54e156 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -150,7 +150,7 @@ def __init__( # specified only after inputs/input shapes are provided self._n_basis_input_ = None - self._input_shape = None + self._input_shape_ = None # set by set_kernel self.kernel_ = None @@ -754,10 +754,10 @@ def _check_input_shape_consistency(self, x: NDArray): # remove sample axis and squeeze shape = x.shape[1:] - initialized = self._input_shape is not None - is_shape_match = self._input_shape == shape + initialized = self._input_shape_ is not None + is_shape_match = self._input_shape_ == shape if initialized and not is_shape_match: - expected_shape_str = "(n_samples, " + f"{self._input_shape}"[1:] + expected_shape_str = "(n_samples, " + f"{self._input_shape_}"[1:] expected_shape_str = expected_shape_str.replace(",)", ")") raise ValueError( f"Input shape mismatch detected.\n\n" @@ -811,7 +811,7 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): n_inputs = (int(np.prod(shape)),) - self._input_shape = shape + self._input_shape_ = shape self._n_basis_input_ = n_inputs return self diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 52031c98..360d6f25 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -284,8 +284,8 @@ def __init__(self, basis1: Basis, basis2: Basis): self.basis1 = copy.deepcopy(basis1) self.basis2 = copy.deepcopy(basis2) shapes = ( - *(bas1._input_shape for bas1 in basis1._list_components()), - *(bas2._input_shape for bas2 in basis2._list_components()), + *(bas1._input_shape_ for bas1 in basis1._list_components()), + *(bas2._input_shape_ for bas2 in basis2._list_components()), ) # if all bases where set, then set input for composition. 
set_bases = (s is not None for s in shapes) diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index e9a566fd..12809858 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -104,7 +104,7 @@ def _unpack_inputs(self, X: FeatureMatrix) -> List: for i, bas in enumerate(self._basis._list_components()): n_input = self._n_basis_input_[i] out.append( - np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape)) + np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape_)) ) cc += n_input return out From 879a3833ad5a70330f194379c4435ff7587d051d Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 6 Dec 2024 13:59:42 -0500 Subject: [PATCH 08/37] fixed typing --- src/nemos/basis/_basis.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 7d54e156..2252d2f9 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -53,7 +53,7 @@ def check_one_dimensional(func: Callable) -> Callable: """Check if the input is one-dimensional.""" @wraps(func) - def wrapper(self: Basis, *xi: ArrayLike, **kwargs): + def wrapper(self: Basis, *xi: NDArray, **kwargs): if any(x.ndim != 1 for x in xi): raise ValueError("Input sample must be one dimensional!") return func(self, *xi, **kwargs) @@ -851,11 +851,6 @@ class AdditiveBasis(CompositeBasisMixin, Basis): basis2 : Second basis object to add. - Attributes - ---------- - n_basis_funcs : int - Number of basis functions. 
- Examples -------- >>> # Generate sample data From 0bff05699557e993e5c557d9bcad789b09f4aa69 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 6 Dec 2024 20:26:18 -0500 Subject: [PATCH 09/37] fixed testing and reset logic --- src/nemos/basis/_basis.py | 257 ++++++++++++++++++-------------- src/nemos/basis/_basis_mixin.py | 147 +++++++++++++++++- src/nemos/basis/basis.py | 15 ++ tests/test_basis.py | 169 +++++++++++++-------- 4 files changed, 411 insertions(+), 177 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 2252d2f9..ec0f5be3 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -4,7 +4,7 @@ import abc import copy from functools import wraps -from typing import Callable, Generator, Literal, Optional, Tuple, Union +from typing import Any, Callable, Generator, Literal, Optional, Tuple, Union import jax import numpy as np @@ -137,7 +137,7 @@ def __init__( label: Optional[str] = None, ) -> None: self._n_basis_funcs = getattr(self, "_n_basis_funcs", None) - self._n_input_dimensionality = 0 + self._n_input_dimensionality = getattr(self, "_n_input_dimensionality", 0) self._mode = mode @@ -155,6 +155,56 @@ def __init__( # set by set_kernel self.kernel_ = None + # initialize parent to None. This should not end in "_" because it is + # a permanent property of a basis, defined at composite basis init + self._parent = None + + def _recompute_kernels(self): + """Recompute all kernels if needed. + + Traverse the tree upwards and reset all input-independent states. + If the node is the root, directly update its states; otherwise, propagate + the request to the parent node. + """ + # Assumes that state updates in the basis tree can be handled independently for each node. + # This is currently true but may change if dependencies are introduced. + # The only such state is self.kernel_, which is set independently for each basis component. 
+ # If dependencies are introduced, use `self.set_kernel` at the root level instead. + # (A basis is the tree root if self._parend is None). + # Note: `self.set_kernel` is more expensive as it recomputes kernels for the entire tree. + update_states = getattr(self, "_reset_all_input_independent_states", None) + if update_states: + update_states() + if getattr(self, "_parent", None): + self._parent._recompute_kernels() + + def _is_init_params_updated(self, name: str, value: Any): + """Check if an attribute set at initialization have been updated.""" + return name in self._get_param_names() + + def __setattr__(self, name: str, value: Any): + """ + Set to None all attributes ending with '_'. + + This __setattr__ resets all the attributes that are defined by a method + like the `kernel_` or `_n_input_shape_` (states of the basis) when an initialization configuration + is updated. + A Basis class must respect the following naming convention: all names of parameters that are settable + by with a method (like `kernel_` computed in `set_kernel`) must end in "_". + + Parameters + ---------- + name : + The name of the attribute to set. + value : + The value to set the attribute to. + """ + # check if the attribute was defined in the __init__ signature + # and if so, then resets all computable states. + super().__setattr__(name, value) + if self._is_init_params_updated(name, value): + self._recompute_kernels() + @property def n_output_features(self) -> int | None: """ @@ -203,43 +253,6 @@ def mode(self): """Mode of operation, either ``"conv"`` or ``"eval"``.""" return self._mode - @staticmethod - def _apply_identifiability_constraints(X: NDArray): - """Apply identifiability constraints to a design matrix `X`. - - Removes columns from `X` until `[1, X]` is full rank to ensure the uniqueness - of the GLM (Generalized Linear Model) maximum-likelihood solution. 
This is particularly - crucial for models using bases like BSplines and CyclicBspline, which, due to their - construction, sum to 1 and can cause rank deficiency when combined with an intercept. - - For GLMs, this rank deficiency means that different sets of coefficients might yield - identical predicted rates and log-likelihood, complicating parameter learning, especially - in the absence of regularization. - - Parameters - ---------- - X: - The design matrix before applying the identifiability constraints. - - Returns - ------- - : - The adjusted design matrix with redundant columns dropped and columns mean-centered. - """ - - def add_constant(x): - return np.hstack((np.ones((x.shape[0], 1)), x)) - - rank = np.linalg.matrix_rank(add_constant(X)) - # mean center - X = X - np.nanmean(X, axis=0) - while rank < X.shape[1] + 1: - # drop a column - X = X[:, :-1] - # recompute rank - rank = np.linalg.matrix_rank(add_constant(X)) - return X - @check_transform_input def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: """ @@ -278,14 +291,107 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: @abc.abstractmethod def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix: - """Convolve or evaluate the basis.""" + """Convolve or evaluate the basis. + + This method is intended to be equivalent to the sklearn transformer ``transform`` method. + As the latter, it computes the transformation assuming that all the states are already + pre-computed by ``_fit_basis``, a method corresponding to ``fit``. + + The method differs from transformer's ``transform`` for the structure of the input that it accepts. + In particular, ``_compute_features`` accepts a number of different time series, one per 1D basis component, + while ``transform`` requires all inputs to be concatenated in a single array. + """ + pass + + @abc.abstractmethod + def _fit_basis(self, *xi: ArrayLike) -> FeatureMatrix: + """Pre-compute all basis state variables. 
+ + This method is intended to be equivalent to the sklearn transformer ``fit`` method. + As the latter, it computes all the state attributes, and store it with the convention + that the attribute name **must** end with "_", for example ``self.kernel_``, + ``self._input_shape_``. + + The method differs from transformer's ``fit`` for the structure of the input that it accepts. + In particular, ``_fit_basis`` accepts a number of different time series, one per 1D basis component, + while ``fit`` requires all inputs to be concatenated in a single array. + """ pass @abc.abstractmethod def set_kernel(self): - """Set kernel for conv basis and return self or just return self for eval.""" + """Set kernel for conv basis and return self or just return self for eval. + + For the basis API to work correctly, specifically, for the `_fit_basis` + method to work as intended, this method should set **all** state attributes + that do not require inspection of input time series. + + This method currently "just" sets the kernel because this is the only such state + but if in the future new states will be added, they must be funneled through this + method. + + Note that the name of this method can and should be refactored in case more such + states will be set in the future. + """ pass + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): + """ + Set the expected input shape for the basis object. + + This method configures the shape of the input data that the basis object expects. + ``xi`` can be specified as an integer, a tuple of integers, or derived + from an array. The method also calculates the total number of input + features and output features based on the number of basis functions. + + Parameters + ---------- + xi : + The input shape specification. + - An integer: Represents the dimensionality of the input. A value of ``1`` is treated as scalar input. + - A tuple: Represents the exact input shape excluding the first axis (sample axis). 
+ All elements must be integers. + - An array: The shape is extracted, excluding the first axis (assumed to be the sample axis). + + Raises + ------ + ValueError + If a tuple is provided and it contains non-integer elements. + + Returns + ------- + self : + Returns the instance itself to allow method chaining. + + Notes + ----- + All state attributes that depends on the input must be set in this method in order for + the API of basis to work correctly. In particular, this method is called by ``_basis_fit``, + which is equivalent to ``fit`` for a transformer. If any input dependent state + is not set in this method, then ``compute_features`` (equivalent to ``fit_transform``) will break. + + Separating states related to the input (settable with this method) and states that are unrelated + from the input (settable with ``set_kernel``) is a deliberate design choice that improves modularity. + + """ + if isinstance(xi, tuple): + if not all(isinstance(i, int) for i in xi): + raise ValueError( + f"The tuple provided contains non integer values. Tuple: {xi}." 
+ ) + shape = xi + elif isinstance(xi, int): + shape = () if xi == 1 else (xi,) + else: + shape = xi.shape[1:] + + n_inputs = (int(np.prod(shape)),) + + self._input_shape_ = shape + + self._n_basis_input_ = n_inputs + return self + @abc.abstractmethod def _evaluate(self, *xi: ArrayLike) -> FeatureMatrix: """ @@ -591,7 +697,7 @@ def _get_feature_slicing( split_by_input=split_by_input, ) sp2, start_slice = self._basis2._get_feature_slicing( - n_inputs[len(self._basis1._n_basis_input_):], + n_inputs[len(self._basis1._n_basis_input_) :], start_slice, split_by_input=split_by_input, ) @@ -745,7 +851,7 @@ def is_leaf(val): key, val = vals shape = list(val.shape) reshaped_out[key] = val.reshape( - shape[:axis] + [self._n_basis_input_[i], -1] + shape[axis + 1:] + shape[:axis] + [self._n_basis_input_[i], -1] + shape[axis + 1 :] ) return reshaped_out @@ -770,52 +876,6 @@ def _check_input_shape_consistency(self, x: NDArray): "different shape, please create a new basis instance." ) - def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): - """ - Set the expected input shape for the basis object. - - This method configures the shape of the input data that the basis object expects. - ``xi`` can be specified as an integer, a tuple of integers, or derived - from an array. The method also calculates the total number of input - features and output features based on the number of basis functions. - - Parameters - ---------- - xi : - The input shape specification. - - An integer: Represents the dimensionality of the input. A value of ``1`` is treated as scalar input. - - A tuple: Represents the exact input shape excluding the first axis (sample axis). - All elements must be integers. - - An array: The shape is extracted, excluding the first axis (assumed to be the sample axis). - - Raises - ------ - ValueError - If a tuple is provided and it contains non-integer elements. - - Returns - ------- - self : - Returns the instance itself to allow method chaining. 
- """ - if isinstance(xi, tuple): - if not all(isinstance(i, int) for i in xi): - raise ValueError( - f"The tuple provided contains non integer values. Tuple: {xi}." - ) - shape = xi - elif isinstance(xi, int): - shape = () if xi == 1 else (xi,) - else: - shape = xi.shape[1:] - - n_inputs = (int(np.prod(shape)),) - - self._input_shape_ = shape - - self._n_basis_input_ = n_inputs - return self - def _list_components(self): """List all basis components. @@ -1271,25 +1331,6 @@ def n_output_features(self): return None return out1 * out2 - def set_kernel(self, *xi: NDArray) -> Basis: - """Call fit on the multiplied basis. - - If any of the added basis is in "conv" mode, it will prepare its kernels for the convolution. - - Parameters - ---------- - *xi: - The sample inputs. Unused, necessary to conform to ``scikit-learn`` API. - - Returns - ------- - : - The MultiplicativeBasis ready to be evaluated. - """ - self._basis1.set_kernel() - self._basis2.set_kernel() - return self - @support_pynapple(conv_type="numpy") @check_transform_input @check_one_dimensional diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 360d6f25..f123272b 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -50,6 +50,30 @@ def _compute_features(self, *xi: NDArray): out = self._evaluate(*(np.reshape(x, (x.shape[0], -1)) for x in xi)) return np.reshape(out, (out.shape[0], -1)) + def _fit_basis(self, *xi: NDArray) -> Basis: + """ + Set all basis states. + + This method corresponds sklearn transformer ``fit``. As fit, it must receive the input and + it must set all basis states, i.e. kernel_ and all the states relative to the input shape. + The difference between this method and the transformer ``fit`` is in the expected input structure, + where the transformer ``fit`` method requires the inputs to be concatenated in a 2D array, while here + each input is provided as a separate time series for each basis element. 
+ + Parameters + ---------- + xi: + Input arrays. + + Returns + ------- + : + The basis with ready for evaluation. + """ + self.set_kernel() + self.set_input_shape(*xi) + return self + def set_kernel(self) -> "EvalBasisMixin": """ Prepare or compute the convolutional kernel for the basis functions. @@ -65,6 +89,15 @@ def set_kernel(self) -> "EvalBasisMixin": """ return self + def _reset_all_input_independent_states(self): + """Set all states that are input independent for self only. + + This method sets all the input independent states. This reimplements an abstract method + of basis, and it is different from ``set_kernel`` because it won't traverse the basis + tree in any basis (including composite basis), while ``set_kernel`` applies to all the tree. + """ + return + @property def bounds(self): """Range of values covered by the basis.""" @@ -119,6 +152,14 @@ def _compute_features(self, *xi: NDArray): The input samples over which to apply the basis transformation. The samples can be passed as multiple arguments, each representing a different dimension for multivariate inputs. + Notes + ----- + This method is intended to be 1-to-1 mappable to sklearn ``transform`` method of transformer. This + means that for the method to be callable, all the state attributes have to be pre-computed in a + method that is mappable to ``fit``, which for us is ``_fit_basis``. It is fundamental that both + methods behaves like the corresponding transformer method, with the only difference being the input + structure: a single (X, y) pair for the transformer, a number of time series for the Basis. + """ if self.kernel_ is None: raise ValueError( @@ -132,6 +173,30 @@ def _compute_features(self, *xi: NDArray): # make sure to return a matrix return np.reshape(conv, newshape=(conv.shape[0], -1)) + def _fit_basis(self, *xi: NDArray) -> Basis: + """ + Set all basis states. + + This method corresponds sklearn transformer ``fit``. 
As fit, it must receive the input and + it must set all basis states, i.e. kernel_ and all the states relative to the input shape. + The difference between this method and the transformer ``fit`` is in the expected input structure, + where the transformer ``fit`` method requires the inputs to be concatenated in a 2D array, while here + each input is provided as a separate time series for each basis element. + + Parameters + ---------- + xi: + Input arrays. + + Returns + ------- + : + The basis with ready for evaluation. + """ + self.set_kernel() + self.set_input_shape(*xi) + return self + def set_kernel(self) -> "ConvBasisMixin": """ Prepare or compute the convolutional kernel for the basis functions. @@ -157,6 +222,33 @@ def set_kernel(self) -> "ConvBasisMixin": self.kernel_ = self._evaluate(np.linspace(0, 1, self.window_size)) return self + def _reset_all_input_independent_states(self): + """Set all states that are input independent for self only. + + This method sets all the input independent states. This reimplements an abstract method + of basis, and it is different from ``set_kernel`` because it won't traverse the basis + tree in any basis (including composite basis), while ``set_kernel`` applies to all the tree. + Called by the setattr of basis. + """ + current_kernel = getattr(self, "kernel_", None) + try: + self.kernel_ = ( + current_kernel + if current_kernel is None + else self._evaluate(np.linspace(0, 1, self.window_size)) + ) + except Exception as e: + # if basis not fully initialized attribute is not there yet. + kernel = getattr(self, "kernel_", None) + if kernel: + warnings.warn( + message=f"Unable to automatically re-initialize the kernel for basis {self.label}, " + f"with exception: {repr(e)}. \n" + f"Resetting the kernel `None`.", + category=UserWarning, + ) + self.kernel_ = None + @property def window_size(self): """Window size as number of samples. 
@@ -168,6 +260,11 @@ def window_size(self): @window_size.setter def window_size(self, window_size): """Setter for the window size parameter.""" + self._check_window_size(window_size) + + self._window_size = window_size + + def _check_window_size(self, window_size): if window_size is None: raise ValueError( "If the basis is in `conv` mode, you must provide a window_size!" @@ -178,8 +275,6 @@ def window_size(self, window_size): f"`window_size` must be a positive integer. {window_size} provided instead!" ) - self._window_size = window_size - @property def conv_kwargs(self): """The convolutional kwargs. @@ -283,6 +378,11 @@ def __init__(self, basis1: Basis, basis2: Basis): # in the composite basis. self.basis1 = copy.deepcopy(basis1) self.basis2 = copy.deepcopy(basis2) + + # set parents + self.basis1._parent = self + self.basis2._parent = self + shapes = ( *(bas1._input_shape_ for bas1 in basis1._list_components()), *(bas2._input_shape_ for bas2 in basis2._list_components()), @@ -309,20 +409,46 @@ def n_basis_funcs(self): def _check_n_basis_min(self) -> None: pass - def set_kernel(self, *xi: NDArray) -> Basis: + def _fit_basis(self, *xi: NDArray) -> Basis: + """ + Set all basis states. + + This method corresponds sklearn transformer ``fit``. As fit, it must receive the input and + it must set all basis states, i.e. kernel_ and all the states relative to the input shape. + The difference between this method and the transformer ``fit`` is in the expected input structure, + where the transformer ``fit`` method requires the inputs to be concatenated in a 2D array, while here + each input is provided as a separate time series for each basis element. + + Parameters + ---------- + xi: + Input arrays. + + Returns + ------- + : + The basis with ready for evaluation. 
+ """ + self.basis1.set_kernel() + self.basis2.set_kernel() + self.basis1.set_input_shape(*xi[: self._basis1._n_input_dimensionality]) + self.basis2.set_input_shape(*xi[self._basis1._n_input_dimensionality :]) + return self + + def set_kernel(self) -> Basis: """Call set_kernel on the basis elements. If any of the basis elements is in "conv" mode, it will prepare its kernels for the convolution. + Addi + Also grabs input shapes if provided, similar to what sklean transformer `fit` method does Parameters ---------- - *xi: - The sample inputs. Unused, necessary to conform to ``scikit-learn`` API. Returns ------- : - The basis ready to be evaluated. + The basis with the kernels set up. """ self._basis1.set_kernel() self._basis2.set_kernel() @@ -364,3 +490,12 @@ def _list_components(self): A list with all 1d basis components. """ return self._basis1._list_components() + self._basis2._list_components() + + def _reset_all_input_independent_states(self): + """Set all states that are input independent for self only. + + This method sets all the input independent states. This reimplements an abstract method + of basis, and it is different from ``set_kernel`` because it won't traverse the basis + tree in any basis (including composite basis), while ``set_kernel`` applies to all the tree. + """ + return diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 720f5811..916b5f17 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -1739,6 +1739,9 @@ def __init__( decay_rates=decay_rates, label=label, ) + # re-check window size because n_basis_funcs is not set yet when the + # property setter runs the first check. + self._check_window_size(self.window_size) @add_docstring("evaluate_on_grid", OrthExponentialBasis) def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: @@ -1830,3 +1833,15 @@ def set_input_shape(self, xi: int | tuple[int, ...] 
| NDArray): """ return super().set_input_shape(xi) + + def _check_window_size(self, window_size: int): + """OrthExponentialBasis specific window size check.""" + super()._check_window_size(window_size) + # if n_basis_funcs is not yet initialized, skip check + n_basis = getattr(self, "n_basis_funcs", None) + if n_basis and window_size < n_basis: + raise ValueError( + "OrthExponentialConv basis requires at least a window_size larger then the number " + f"of basis functions. window_size is {window_size}, n_basis_funcs while" + f"is {self.n_basis_funcs}." + ) diff --git a/tests/test_basis.py b/tests/test_basis.py index 8a4d7965..b79102d7 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -519,7 +519,7 @@ def test_vmin_vmax_init(self, bounds, expectation, cls): @pytest.mark.parametrize("n_basis", [6, 7]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_basis_number(self, n_basis, mode, kwargs, cls): @@ -551,7 +551,7 @@ def test_call_equivalent_in_conv(self, n_basis, cls): ], ) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) @pytest.mark.parametrize("n_basis", [6]) def test_call_input_num(self, num_input, n_basis, mode, kwargs, expectation, cls): @@ -571,7 +571,7 @@ def test_call_input_num(self, num_input, n_basis, mode, kwargs, expectation, cls ) @pytest.mark.parametrize("n_basis", [6]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_input_shape(self, inp, mode, kwargs, expectation, n_basis, cls): bas = cls[mode]( @@ -585,7 +585,7 @@ def test_call_input_shape(self, inp, mode, kwargs, expectation, n_basis, cls): @pytest.mark.parametrize("n_basis", [6]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", 
{"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_nan_location(self, mode, kwargs, n_basis, cls): bas = cls[mode]( @@ -618,7 +618,7 @@ def test_call_input_type(self, samples, expectation, n_basis, cls): bas._evaluate(samples) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_nan(self, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) @@ -628,7 +628,7 @@ def test_call_nan(self, mode, kwargs, cls): @pytest.mark.parametrize("n_basis", [6, 7]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_non_empty(self, n_basis, mode, kwargs, cls): bas = cls[mode]( @@ -639,7 +639,7 @@ def test_call_non_empty(self, n_basis, mode, kwargs, cls): @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_sample_axis(self, time_axis_shape, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) @@ -656,7 +656,7 @@ def test_call_sample_axis(self, time_axis_shape, mode, kwargs, cls): ], ) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_sample_range(self, mn, mx, expectation, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) @@ -912,7 +912,7 @@ def test_convolution_is_performed(self, cls): @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", 
{"window_size": 8})] ) def test_evaluate_on_grid_basis_size(self, sample_size, mode, kwargs, cls): if "OrthExp" in cls["eval"].__name__: @@ -931,7 +931,7 @@ def test_evaluate_on_grid_basis_size(self, sample_size, mode, kwargs, cls): @pytest.mark.parametrize("n_input", [0, 1, 2]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_evaluate_on_grid_input_number(self, n_input, mode, kwargs, cls): basis_obj = cls[mode]( @@ -956,7 +956,7 @@ def test_evaluate_on_grid_input_number(self, n_input, mode, kwargs, cls): @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_evaluate_on_grid_meshgrid_size(self, sample_size, mode, kwargs, cls): if "OrthExp" in cls["eval"].__name__: @@ -990,7 +990,7 @@ def test_fit_kernel_shape(self, cls): @pytest.mark.parametrize( "mode, ws, expectation", [ - ("conv", 2, does_not_raise()), + ("conv", 5, does_not_raise()), ( "conv", -1, @@ -1033,9 +1033,9 @@ def test_init_window_size(self, mode, ws, expectation, cls): n_basis_funcs=5, window_size=ws, **extra_decay_rates(cls[mode], 5) ) - @pytest.mark.parametrize("samples", [[], [0], [0, 0]]) + @pytest.mark.parametrize("samples", [[], [0] * 10, [0] * 11]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_non_empty_samples(self, samples, mode, kwargs, cls): if "OrthExp" in cls["eval"].__name__: @@ -1080,7 +1080,7 @@ def test_number_of_required_inputs_compute_features( basis_obj.compute_features(*inputs) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_pynapple_support(self, mode, kwargs, cls): bas 
= cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) @@ -1239,7 +1239,7 @@ def test_set_window_size(self, mode, expectation, cls): def test_transform_fails(self, cls): bas = cls["conv"]( - n_basis_funcs=5, window_size=3, **extra_decay_rates(cls["conv"], 5) + n_basis_funcs=5, window_size=5, **extra_decay_rates(cls["conv"], 5) ) with pytest.raises( ValueError, match="You must call `_set_kernel` before `_compute_features`" @@ -1353,7 +1353,7 @@ def test_decay_to_zero_basis_number_match(self, width): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_minimum_number_of_basis_required_is_matched( self, n_basis_funcs, mode, kwargs @@ -1464,7 +1464,7 @@ def test_time_scaling_values(self, time_scaling, expectation, mode, kwargs): ], ) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_width_values(self, width, expectation, mode, kwargs): with expectation: @@ -1476,7 +1476,7 @@ class TestRaisedCosineLinearBasis(BasisFuncsTesting): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_minimum_number_of_basis_required_is_matched( self, n_basis_funcs, mode, kwargs @@ -1551,7 +1551,7 @@ class TestMSplineBasis(BasisFuncsTesting): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("order", [-1, 0, 1, 2, 3, 4, 5]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_minimum_number_of_basis_required_is_matched( self, n_basis_funcs, order, mode, kwargs @@ -1652,6 
+1652,50 @@ def test_vmin_vmax_eval_on_grid_scaling_effect_on_eval( class TestOrthExponentialBasis(BasisFuncsTesting): cls = {"eval": basis.OrthExponentialEval, "conv": basis.OrthExponentialConv} + @pytest.mark.parametrize( + "window_size, n_basis, expectation", + [ + ( + 4, + 5, + pytest.raises( + ValueError, + match="OrthExponentialConv basis requires at least a window_size", + ), + ), + (5, 5, does_not_raise()), + ], + ) + def test_window_size_at_init(self, window_size, n_basis, expectation): + decay_rates = np.asarray(np.arange(1, n_basis + 1), dtype=float) + with expectation: + self.cls["conv"](n_basis, decay_rates=decay_rates, window_size=window_size) + + @pytest.mark.parametrize( + "window_size, n_basis, expectation", + [ + ( + 4, + 5, + pytest.raises( + ValueError, + match="OrthExponentialConv basis requires at least a window_size", + ), + ), + (5, 5, does_not_raise()), + ], + ) + def test_window_size_at_init(self, window_size, n_basis, expectation): + decay_rates = np.asarray(np.arange(1, n_basis + 1), dtype=float) + obj = self.cls["conv"]( + n_basis, decay_rates=decay_rates, window_size=n_basis + 1 + ) + with expectation: + obj.window_size = window_size + + with expectation: + obj.set_params(window_size=window_size) + @pytest.mark.parametrize( "decay_rates", [[1, 2, 3], [0.01, 0.02, 0.001], [2, 1, 1, 2.4]] ) @@ -1727,7 +1771,7 @@ class TestBSplineBasis(BasisFuncsTesting): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("order", [1, 2, 3, 4, 5]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_minimum_number_of_basis_required_is_matched( self, n_basis_funcs, order, mode, kwargs @@ -1815,7 +1859,7 @@ class TestCyclicBSplineBasis(BasisFuncsTesting): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("order", [2, 3, 4, 5]) @pytest.mark.parametrize( - "mode, kwargs", 
[("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_minimum_number_of_basis_required_is_matched( self, n_basis_funcs, order, mode, kwargs @@ -2235,7 +2279,7 @@ def test_pynapple_support_compute_features( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) @pytest.mark.parametrize("num_input", [0, 1, 2, 3, 4, 5]) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) def test_call_input_num( self, n_basis_a, @@ -2269,7 +2313,7 @@ def test_call_input_num( (np.linspace(0, 1, 10)[:, None], pytest.raises(ValueError)), ], ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2296,7 +2340,7 @@ def test_call_input_shape( basis_obj._evaluate(*([inp] * basis_obj._n_input_dimensionality)) @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2321,7 +2365,7 @@ def test_call_sample_axis( inp = [np.linspace(0, 1, time_axis_shape)] * basis_obj._n_input_dimensionality assert basis_obj._evaluate(*inp).shape[0] == time_axis_shape - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2354,10 +2398,10 @@ def test_call_equivalent_in_conv( self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, 
basis_a, class_specific_params, window_size=3 + n_basis_a, basis_a, class_specific_params, window_size=9 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=3 + n_basis_b, basis_b, class_specific_params, window_size=9 ) bas_eva = basis_a_obj + basis_b_obj @@ -2372,7 +2416,7 @@ def test_call_equivalent_in_conv( x = [np.linspace(0, 1, 10)] * bas_con._n_input_dimensionality assert np.all(bas_con._evaluate(*x) == bas_eva._evaluate(*x)) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2396,7 +2440,7 @@ def test_pynapple_support( assert np.all(y == y_nap.d) assert np.all(y_nap.t == x_nap[0].t) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [6, 7]) @@ -2417,7 +2461,7 @@ def test_call_basis_number( == basis_a_obj.n_basis_funcs + basis_b_obj.n_basis_funcs ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2443,7 +2487,7 @@ def test_call_non_empty( (0.1, 2, does_not_raise()), ], ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2780,21 +2824,21 @@ def test_compute_n_basis_runtime(self, basis_a, basis_b, class_specific_params): add.basis2.n_basis_funcs = 10 assert add.n_basis_funcs 
== 20 - @pytest.mark.parametrize( - "basis_a", list_all_basis_classes() - ) - @pytest.mark.parametrize( - "basis_b", list_all_basis_classes() - ) + @pytest.mark.parametrize("basis_a", list_all_basis_classes()) + @pytest.mark.parametrize("basis_b", list_all_basis_classes()) def test_runtime_n_basis_out_compute(self, basis_a, basis_b, class_specific_params): basis_a = self.instantiate_basis( 5, basis_a, class_specific_params, window_size=10 ) - basis_a.set_input_shape(*([1] * basis_a._n_input_dimensionality)).to_transformer() + basis_a.set_input_shape( + *([1] * basis_a._n_input_dimensionality) + ).to_transformer() basis_b = self.instantiate_basis( 5, basis_b, class_specific_params, window_size=10 ) - basis_b.set_input_shape(*([1] * basis_b._n_input_dimensionality)).to_transformer() + basis_b.set_input_shape( + *([1] * basis_b._n_input_dimensionality) + ).to_transformer() add = basis_a + basis_b inps_a = [2] * basis_a._n_input_dimensionality add.basis1.set_input_shape(*inps_a) @@ -2811,6 +2855,7 @@ def test_runtime_n_basis_out_compute(self, basis_a, basis_b, class_specific_para add.basis2.set_input_shape(*inps_b) assert add.n_output_features == new_out_num + new_out_num_b + class TestMultiplicativeBasis(CombinedBasis): cls = {"eval": MultiplicativeBasis, "conv": MultiplicativeBasis} @@ -3110,7 +3155,7 @@ def test_pynapple_support_compute_features( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) @pytest.mark.parametrize("num_input", [0, 1, 2, 3, 4, 5]) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) def test_call_input_num( self, n_basis_a, @@ -3144,7 +3189,7 @@ def test_call_input_num( (np.linspace(0, 1, 10)[:, None], pytest.raises(ValueError)), ], ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) 
@pytest.mark.parametrize("n_basis_a", [5]) @@ -3171,7 +3216,7 @@ def test_call_input_shape( basis_obj._evaluate(*([inp] * basis_obj._n_input_dimensionality)) @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -3196,7 +3241,7 @@ def test_call_sample_axis( inp = [np.linspace(0, 1, time_axis_shape)] * basis_obj._n_input_dimensionality assert basis_obj._evaluate(*inp).shape[0] == time_axis_shape - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -3247,7 +3292,7 @@ def test_call_equivalent_in_conv( x = [np.linspace(0, 1, 10)] * bas_con._n_input_dimensionality assert np.all(bas_con._evaluate(*x) == bas_eva._evaluate(*x)) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -3271,7 +3316,7 @@ def test_pynapple_support( assert np.all(y == y_nap.d) assert np.all(y_nap.t == x_nap[0].t) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [6, 7]) @@ -3292,7 +3337,7 @@ def test_call_basis_number( == basis_a_obj.n_basis_funcs * basis_b_obj.n_basis_funcs ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", 
list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -3318,7 +3363,7 @@ def test_call_non_empty( (0.1, 2, does_not_raise()), ], ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -3666,21 +3711,21 @@ def test_compute_n_basis_runtime(self, basis_a, basis_b, class_specific_params): mul.basis2.n_basis_funcs = 10 assert mul.n_basis_funcs == 100 - @pytest.mark.parametrize( - "basis_a", list_all_basis_classes() - ) - @pytest.mark.parametrize( - "basis_b", list_all_basis_classes() - ) + @pytest.mark.parametrize("basis_a", list_all_basis_classes()) + @pytest.mark.parametrize("basis_b", list_all_basis_classes()) def test_runtime_n_basis_out_compute(self, basis_a, basis_b, class_specific_params): basis_a = self.instantiate_basis( 5, basis_a, class_specific_params, window_size=10 ) - basis_a.set_input_shape(*([1] * basis_a._n_input_dimensionality)).to_transformer() + basis_a.set_input_shape( + *([1] * basis_a._n_input_dimensionality) + ).to_transformer() basis_b = self.instantiate_basis( 5, basis_b, class_specific_params, window_size=10 ) - basis_b.set_input_shape(*([1] * basis_b._n_input_dimensionality)).to_transformer() + basis_b.set_input_shape( + *([1] * basis_b._n_input_dimensionality) + ).to_transformer() mul = basis_a * basis_b inps_a = [2] * basis_a._n_input_dimensionality mul.basis1.set_input_shape(*inps_a) @@ -3744,8 +3789,6 @@ def test_power_of_basis(exponent, basis_class, class_specific_params): assert np.all(np.isnan(out[~non_nan])) - - @pytest.mark.parametrize( "basis_cls", list_all_basis_classes(), @@ -3876,7 +3919,7 @@ def test_basis_to_transformer_makes_a_copy(basis_cls, class_specific_params): @pytest.mark.parametrize("n_basis_funcs", [5, 10, 20]) def 
test_transformerbasis_getattr(basis_cls, n_basis_funcs, class_specific_params): bas = CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=10 + n_basis_funcs, basis_cls, class_specific_params, window_size=30 ) trans_basis = basis.TransformerBasis( bas.set_input_shape(*([1] * bas._n_input_dimensionality)) @@ -3918,14 +3961,14 @@ def test_transformerbasis_set_params( def test_transformerbasis_setattr_basis(basis_cls, class_specific_params): # setting the _basis attribute should change it bas = CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=10 + 10, basis_cls, class_specific_params, window_size=30 ) trans_bas = basis.TransformerBasis( bas.set_input_shape(*([1] * bas._n_input_dimensionality)) ) bas = CombinedBasis().instantiate_basis( - 20, basis_cls, class_specific_params, window_size=10 + 20, basis_cls, class_specific_params, window_size=30 ) trans_bas.basis = bas.set_input_shape(*([1] * bas._n_input_dimensionality)) @@ -4369,8 +4412,8 @@ def test_multi_epoch_pynapple_basis_transformer( bas1._n_basis_input_[0] * bas1.n_basis_funcs * ( - bas2._n_basis_input_[0] * bas2.n_basis_funcs - + bas3._n_basis_input_[0] * bas3.n_basis_funcs + bas2._n_basis_input_[0] * bas2.n_basis_funcs + + bas3._n_basis_input_[0] * bas3.n_basis_funcs ), ), }, From 861b94677938ae86df9bd91db8b221f19c5df99f Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 9 Dec 2024 10:54:41 -0500 Subject: [PATCH 10/37] allow multi-dim inputs --- src/nemos/basis/_basis.py | 50 ++---------------------------- src/nemos/basis/_basis_mixin.py | 54 +++------------------------------ src/nemos/basis/basis.py | 14 +++++++++ 3 files changed, 20 insertions(+), 98 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index ec0f5be3..3fdc9c59 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -4,7 +4,7 @@ import abc import copy from functools import wraps -from typing import Any, 
Callable, Generator, Literal, Optional, Tuple, Union +from typing import Callable, Generator, Literal, Optional, Tuple, Union import jax import numpy as np @@ -159,52 +159,6 @@ def __init__( # a permanent property of a basis, defined at composite basis init self._parent = None - def _recompute_kernels(self): - """Recompute all kernels if needed. - - Traverse the tree upwards and reset all input-independent states. - If the node is the root, directly update its states; otherwise, propagate - the request to the parent node. - """ - # Assumes that state updates in the basis tree can be handled independently for each node. - # This is currently true but may change if dependencies are introduced. - # The only such state is self.kernel_, which is set independently for each basis component. - # If dependencies are introduced, use `self.set_kernel` at the root level instead. - # (A basis is the tree root if self._parend is None). - # Note: `self.set_kernel` is more expensive as it recomputes kernels for the entire tree. - update_states = getattr(self, "_reset_all_input_independent_states", None) - if update_states: - update_states() - if getattr(self, "_parent", None): - self._parent._recompute_kernels() - - def _is_init_params_updated(self, name: str, value: Any): - """Check if an attribute set at initialization have been updated.""" - return name in self._get_param_names() - - def __setattr__(self, name: str, value: Any): - """ - Set to None all attributes ending with '_'. - - This __setattr__ resets all the attributes that are defined by a method - like the `kernel_` or `_n_input_shape_` (states of the basis) when an initialization configuration - is updated. - A Basis class must respect the following naming convention: all names of parameters that are settable - by with a method (like `kernel_` computed in `set_kernel`) must end in "_". - - Parameters - ---------- - name : - The name of the attribute to set. - value : - The value to set the attribute to. 
- """ - # check if the attribute was defined in the __init__ signature - # and if so, then resets all computable states. - super().__setattr__(name, value) - if self._is_init_params_updated(name, value): - self._recompute_kernels() - @property def n_output_features(self) -> int | None: """ @@ -304,7 +258,7 @@ def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix: pass @abc.abstractmethod - def _fit_basis(self, *xi: ArrayLike) -> FeatureMatrix: + def setup_basis(self, *xi: ArrayLike) -> FeatureMatrix: """Pre-compute all basis state variables. This method is intended to be equivalent to the sklearn transformer ``fit`` method. diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index f123272b..ad40b68e 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -50,7 +50,7 @@ def _compute_features(self, *xi: NDArray): out = self._evaluate(*(np.reshape(x, (x.shape[0], -1)) for x in xi)) return np.reshape(out, (out.shape[0], -1)) - def _fit_basis(self, *xi: NDArray) -> Basis: + def set_basis(self, *xi: NDArray) -> Basis: """ Set all basis states. @@ -89,15 +89,6 @@ def set_kernel(self) -> "EvalBasisMixin": """ return self - def _reset_all_input_independent_states(self): - """Set all states that are input independent for self only. - - This method sets all the input independent states. This reimplements an abstract method - of basis, and it is different from ``set_kernel`` because it won't traverse the basis - tree in any basis (including composite basis), while ``set_kernel`` applies to all the tree. - """ - return - @property def bounds(self): """Range of values covered by the basis.""" @@ -173,7 +164,7 @@ def _compute_features(self, *xi: NDArray): # make sure to return a matrix return np.reshape(conv, newshape=(conv.shape[0], -1)) - def _fit_basis(self, *xi: NDArray) -> Basis: + def setup_basis(self, *xi: NDArray) -> Basis: """ Set all basis states. 
@@ -222,33 +213,6 @@ def set_kernel(self) -> "ConvBasisMixin": self.kernel_ = self._evaluate(np.linspace(0, 1, self.window_size)) return self - def _reset_all_input_independent_states(self): - """Set all states that are input independent for self only. - - This method sets all the input independent states. This reimplements an abstract method - of basis, and it is different from ``set_kernel`` because it won't traverse the basis - tree in any basis (including composite basis), while ``set_kernel`` applies to all the tree. - Called by the setattr of basis. - """ - current_kernel = getattr(self, "kernel_", None) - try: - self.kernel_ = ( - current_kernel - if current_kernel is None - else self._evaluate(np.linspace(0, 1, self.window_size)) - ) - except Exception as e: - # if basis not fully initialized attribute is not there yet. - kernel = getattr(self, "kernel_", None) - if kernel: - warnings.warn( - message=f"Unable to automatically re-initialize the kernel for basis {self.label}, " - f"with exception: {repr(e)}. \n" - f"Resetting the kernel `None`.", - category=UserWarning, - ) - self.kernel_ = None - @property def window_size(self): """Window size as number of samples. @@ -409,7 +373,7 @@ def n_basis_funcs(self): def _check_n_basis_min(self) -> None: pass - def _fit_basis(self, *xi: NDArray) -> Basis: + def setup_basis(self, *xi: NDArray) -> Basis: """ Set all basis states. @@ -429,8 +393,7 @@ def _fit_basis(self, *xi: NDArray) -> Basis: : The basis with ready for evaluation. """ - self.basis1.set_kernel() - self.basis2.set_kernel() + self.set_kernel() self.basis1.set_input_shape(*xi[: self._basis1._n_input_dimensionality]) self.basis2.set_input_shape(*xi[self._basis1._n_input_dimensionality :]) return self @@ -490,12 +453,3 @@ def _list_components(self): A list with all 1d basis components. 
""" return self._basis1._list_components() + self._basis2._list_components() - - def _reset_all_input_independent_states(self): - """Set all states that are input independent for self only. - - This method sets all the input independent states. This reimplements an abstract method - of basis, and it is different from ``set_kernel`` because it won't traverse the basis - tree in any basis (including composite basis), while ``set_kernel`` applies to all the tree. - """ - return diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 916b5f17..9a144728 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -1845,3 +1845,17 @@ def _check_window_size(self, window_size: int): f"of basis functions. window_size is {window_size}, n_basis_funcs while" f"is {self.n_basis_funcs}." ) + + def set_kernel(self): + try: + super().set_kernel() + except ValueError as e: + if "OrthExponentialBasis requires at least as many" in str(e): + raise ValueError("Cannot set the kernels for OrthExponentialBasis when `window_size` is smaller " + "than `n_basis_funcs.\n" + "Please, increase the window size or reduce the number of basis functions. 
" + f"Current `window_size` is {self.window_size}, while `n_basis_funcs` is " + f"{self.n_basis_funcs}.") + else: + raise e + return self From aaf413b88ab48c580741f5728f3e0654c28c235e Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 9 Dec 2024 12:39:14 -0500 Subject: [PATCH 11/37] added tests for fit and fit_transform on input types --- src/nemos/basis/_basis.py | 4 +- src/nemos/basis/_basis_mixin.py | 2 +- src/nemos/basis/_transformer_basis.py | 50 +- src/nemos/basis/basis.py | 12 +- tests/conftest.py | 103 +++ tests/test_basis.py | 999 ++++++++------------------ tests/test_transformer_basis.py | 660 +++++++++++++++++ 7 files changed, 1124 insertions(+), 706 deletions(-) create mode 100644 tests/test_transformer_basis.py diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 3fdc9c59..3dd7f3c7 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -149,8 +149,8 @@ def __init__( self._check_n_basis_min() # specified only after inputs/input shapes are provided - self._n_basis_input_ = None - self._input_shape_ = None + self._n_basis_input_ = getattr(self, "_n_basis_input_", None) + self._input_shape_ = getattr(self, "_input_shape_", None) # set by set_kernel self.kernel_ = None diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index ad40b68e..423b9588 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -50,7 +50,7 @@ def _compute_features(self, *xi: NDArray): out = self._evaluate(*(np.reshape(x, (x.shape[0], -1)) for x in xi)) return np.reshape(out, (out.shape[0], -1)) - def set_basis(self, *xi: NDArray) -> Basis: + def setup_basis(self, *xi: NDArray) -> Basis: """ Set all basis states. 
diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 12809858..165079a3 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -62,13 +62,14 @@ class TransformerBasis: """ def __init__(self, basis: Basis): + self._check_initialized(basis) self._basis = copy.deepcopy(basis) @staticmethod def _check_initialized(basis): if basis._n_basis_input_ is None: raise RuntimeError( - "TransformerBasis initialization failed: the provided basis has no defined input shape. " + "Cannot initialize TransformerBasis: the provided basis has no defined input shape. " "Please call `set_input_shape` on the basis before initializing the transformer." ) @@ -101,7 +102,7 @@ def _unpack_inputs(self, X: FeatureMatrix) -> List: n_samples = X.shape[0] out = [] cc = 0 - for i, bas in enumerate(self._basis._list_components()): + for i, bas in enumerate(self._list_components()): n_input = self._n_basis_input_[i] out.append( np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape_)) @@ -113,7 +114,11 @@ def fit(self, X: FeatureMatrix, y=None): """ Compute the convolutional kernels. - If any of the 1D basis in self._basis is in "conv" mode, it computes the convolutional kernels. + Checks the input structure and, if any of the 1D basis in self._basis is in "conv" mode, + it computes the convolutional kernels. + + Note that the input must be 2-dimensional, and the number of column must match the number of inputs + that the basis expect. The number of input can be reset by calling the ``set_input_shape`` method. Parameters ---------- @@ -127,6 +132,11 @@ def fit(self, X: FeatureMatrix, y=None): self : The transformer object. + Raises + ------ + ValueError: + If the number of columns in X do not match the number of inputs that the basis expects. 
+ Examples -------- >>> import numpy as np @@ -140,6 +150,7 @@ def fit(self, X: FeatureMatrix, y=None): >>> transformer = TransformerBasis(basis) >>> transformer_fitted = transformer.fit(X) """ + self._check_input(X, y) self._basis.set_kernel() return self @@ -219,6 +230,7 @@ def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: >>> # Fit and transform basis >>> feature_transformed = transformer.fit_transform(X) """ + self.fit(X, y=y) return self._basis.compute_features(*self._unpack_inputs(X)) def __getstate__(self): @@ -416,3 +428,35 @@ def __pow__(self, exponent: int) -> TransformerBasis: """ # errors are handled by Basis.__pow__ return TransformerBasis(self._basis**exponent) + + def _check_input(self, X: FeatureMatrix, y=None): + """Check that the input structure. + + TransformerBasis expects a 2-d array as an input. The number of columns should match the number of inputs + the basis expects. This number can be set before the TransformerBasis is initialized, by calling + ``Basis.set_input_shape``. + + Parameters + ---------- + X: + The input FeatureMatrix. + + Raises + ------ + ValueError: + If the input is not a 2-d array or if the number of columns does not match the expected number of inputs. + """ + ndim = getattr(X, "ndim", None) + if ndim is None or y is not None: + raise ValueError("The input must be a 2-dimensional array.") + + elif ndim != 2: + raise ValueError( + f"X must be 2-dimensional, shape (n_samples, n_features). The provided X has shape {X.shape} instead." + ) + + if X.shape[1] != sum(self.n_basis_input_): + raise ValueError( + f"Input mismatch: expected {sum(self.n_basis_input_)} inputs, but got {X.shape[1]} columns in X.\n" + "To modify the required number of inputs, call `set_input_shape` before using `fit` or `fit_transform`." 
+ ) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 9a144728..6feab33a 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -1851,11 +1851,13 @@ def set_kernel(self): super().set_kernel() except ValueError as e: if "OrthExponentialBasis requires at least as many" in str(e): - raise ValueError("Cannot set the kernels for OrthExponentialBasis when `window_size` is smaller " - "than `n_basis_funcs.\n" - "Please, increase the window size or reduce the number of basis functions. " - f"Current `window_size` is {self.window_size}, while `n_basis_funcs` is " - f"{self.n_basis_funcs}.") + raise ValueError( + "Cannot set the kernels for OrthExponentialBasis when `window_size` is smaller " + "than `n_basis_funcs.\n" + "Please, increase the window size or reduce the number of basis functions. " + f"Current `window_size` is {self.window_size}, while `n_basis_funcs` is " + f"{self.n_basis_funcs}." + ) else: raise e return self diff --git a/tests/conftest.py b/tests/conftest.py index eb88ed10..3daba960 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,8 @@ and loading predefined parameters for testing various functionalities of the NeMoS library. """ +import abc + import jax import jax.numpy as jnp import numpy as np @@ -16,11 +18,112 @@ import pytest import nemos as nmo +import nemos._inspect_utils as inspect_utils +import nemos.basis.basis as basis +from nemos.basis import AdditiveBasis, MultiplicativeBasis +from nemos.basis._basis import Basis # shut-off conversion warnings nap.nap_config.suppress_conversion_warnings = True +@pytest.fixture() +def basis_class_specific_params(): + """Returns all the params for each class.""" + all_cls = list_all_basis_classes("Conv") + list_all_basis_classes("Eval") + return {cls.__name__: cls._get_param_names() for cls in all_cls} + + +class BasisFuncsTesting(abc.ABC): + """ + An abstract base class that sets the foundation for individual basis function testing. 
+ This class requires an implementation of a 'cls' method, which is utilized by the meta-test + that verifies if all basis functions are properly tested. + """ + + @abc.abstractmethod + def cls(self): + pass + + +class CombinedBasis(BasisFuncsTesting): + """ + This class is used to run tests on combination operations (e.g., addition, multiplication) among Basis functions. + + Properties: + - cls: Class (default = None) + """ + + cls = None + + @staticmethod + def instantiate_basis( + n_basis, basis_class, class_specific_params, window_size=10, **kwargs + ): + """Instantiate and return two basis of the type specified.""" + + # Set non-optional args + default_kwargs = { + "n_basis_funcs": n_basis, + "window_size": window_size, + "decay_rates": np.arange(1, 1 + n_basis), + } + repeated_keys = set(default_kwargs.keys()).intersection(kwargs.keys()) + if repeated_keys: + raise ValueError( + "Cannot set `n_basis_funcs, window_size, decay_rates` with kwargs" + ) + + # Merge with provided extra kwargs + kwargs = {**default_kwargs, **kwargs} + + if basis_class == AdditiveBasis: + kwargs_mspline = inspect_utils.trim_kwargs( + basis.MSplineEval, kwargs, class_specific_params + ) + kwargs_raised_cosine = inspect_utils.trim_kwargs( + basis.RaisedCosineLinearConv, kwargs, class_specific_params + ) + b1 = basis.MSplineEval(**kwargs_mspline) + b2 = basis.RaisedCosineLinearConv(**kwargs_raised_cosine) + basis_obj = b1 + b2 + elif basis_class == MultiplicativeBasis: + kwargs_mspline = inspect_utils.trim_kwargs( + basis.MSplineEval, kwargs, class_specific_params + ) + kwargs_raised_cosine = inspect_utils.trim_kwargs( + basis.RaisedCosineLinearConv, kwargs, class_specific_params + ) + b1 = basis.MSplineEval(**kwargs_mspline) + b2 = basis.RaisedCosineLinearConv(**kwargs_raised_cosine) + basis_obj = b1 * b2 + else: + basis_obj = basis_class( + **inspect_utils.trim_kwargs(basis_class, kwargs, class_specific_params) + ) + return basis_obj + + +# automatic define user accessible basis and 
check the methods +def list_all_basis_classes(filter_basis="all") -> list[type]: + """ + Return all the classes in nemos.basis which are a subclass of Basis, + which should be all concrete classes except TransformerBasis. + """ + all_basis = [ + class_obj + for _, class_obj in inspect_utils.get_non_abstract_classes(basis) + if issubclass(class_obj, Basis) + ] + [ + bas + for _, bas in inspect_utils.get_non_abstract_classes(nmo.basis._basis) + if bas != basis.TransformerBasis + ] + if filter_basis != "all": + all_basis = [a for a in all_basis if filter_basis in a.__name__] + return all_basis + + # Sample subclass to test instantiation and methods class MockRegressor(nmo.base_regressor.BaseRegressor): """ diff --git a/tests/test_basis.py b/tests/test_basis.py index b79102d7..55f4335f 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -1,7 +1,5 @@ -import abc import inspect import itertools -import pickle import re from contextlib import nullcontext as does_not_raise from functools import partial @@ -11,9 +9,8 @@ import numpy as np import pynapple as nap import pytest -from sklearn.base import clone as sk_clone +from conftest import BasisFuncsTesting, CombinedBasis, list_all_basis_classes -import nemos as nmo import nemos._inspect_utils as inspect_utils import nemos.basis.basis as basis import nemos.convolve as convolve @@ -34,33 +31,6 @@ def extra_decay_rates(cls, n_basis): return {} -# automatic define user accessible basis and check the methods -def list_all_basis_classes(filter_basis="all") -> list[type]: - """ - Return all the classes in nemos.basis which are a subclass of Basis, - which should be all concrete classes except TransformerBasis. 
- """ - all_basis = [ - class_obj - for _, class_obj in inspect_utils.get_non_abstract_classes(basis) - if issubclass(class_obj, Basis) - ] + [ - bas - for _, bas in inspect_utils.get_non_abstract_classes(nmo.basis._basis) - if bas != basis.TransformerBasis - ] - if filter_basis != "all": - all_basis = [a for a in all_basis if filter_basis in a.__name__] - return all_basis - - -@pytest.fixture() -def class_specific_params(): - """Returns all the params for each class.""" - all_cls = list_all_basis_classes("Conv") + list_all_basis_classes("Eval") - return {cls.__name__: cls._get_param_names() for cls in all_cls} - - def test_all_basis_are_tested() -> None: """Meta-test. @@ -137,11 +107,11 @@ def test_all_basis_are_tested() -> None: ], ) def test_example_docstrings_add( - basis_cls, method_name, descr_match, class_specific_params + basis_cls, method_name, descr_match, basis_class_specific_params ): basis_instance = CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 + 5, basis_cls, basis_class_specific_params, window_size=10 ) method = getattr(basis_instance, method_name) doc = method.__doc__ @@ -285,18 +255,6 @@ def test_expected_output_split_by_feature(basis_instance, super_class): np.testing.assert_array_equal(xx[~nans], x[~nans]) -class BasisFuncsTesting(abc.ABC): - """ - An abstract base class that sets the foundation for individual basis function testing. - This class requires an implementation of a 'cls' method, which is utilized by the meta-test - that verifies if all basis functions are properly tested. 
- """ - - @abc.abstractmethod - def cls(self): - pass - - @pytest.mark.parametrize( "cls", [ @@ -721,7 +679,7 @@ def test_compute_features_conv_input( order, width, cls, - class_specific_params, + basis_class_specific_params, ): x = np.ones(input_shape) @@ -736,7 +694,9 @@ def test_compute_features_conv_input( ) # figure out which kwargs needs to be removed - kwargs = inspect_utils.trim_kwargs(cls["conv"], kwargs, class_specific_params) + kwargs = inspect_utils.trim_kwargs( + cls["conv"], kwargs, basis_class_specific_params + ) basis_obj = cls["conv"](**kwargs) out = basis_obj.compute_features(x) @@ -1178,7 +1138,7 @@ def test_set_params( decay_rates, conv_kwargs, cls, - class_specific_params, + basis_class_specific_params, ): """Test the read-only and read/write property of the parameters.""" pars = dict( @@ -1195,7 +1155,7 @@ def test_set_params( pars = { key: value for key, value in pars.items() - if key in class_specific_params[cls[mode].__name__] + if key in basis_class_specific_params[cls[mode].__name__] } keys = list(pars.keys()) @@ -1960,72 +1920,16 @@ def test_samples_range_matches_compute_features_requirements( basis_obj.compute_features(np.linspace(*sample_range, 100)) -class CombinedBasis(BasisFuncsTesting): - """ - This class is used to run tests on combination operations (e.g., addition, multiplication) among Basis functions. 
- - Properties: - - cls: Class (default = None) - """ - - cls = None - - @staticmethod - def instantiate_basis( - n_basis, basis_class, class_specific_params, window_size=10, **kwargs - ): - """Instantiate and return two basis of the type specified.""" - - # Set non-optional args - default_kwargs = { - "n_basis_funcs": n_basis, - "window_size": window_size, - "decay_rates": np.arange(1, 1 + n_basis), - } - repeated_keys = set(default_kwargs.keys()).intersection(kwargs.keys()) - if repeated_keys: - raise ValueError( - "Cannot set `n_basis_funcs, window_size, decay_rates` with kwargs" - ) - - # Merge with provided extra kwargs - kwargs = {**default_kwargs, **kwargs} - - if basis_class == AdditiveBasis: - kwargs_mspline = inspect_utils.trim_kwargs( - basis.MSplineEval, kwargs, class_specific_params - ) - kwargs_raised_cosine = inspect_utils.trim_kwargs( - basis.RaisedCosineLinearConv, kwargs, class_specific_params - ) - b1 = basis.MSplineEval(**kwargs_mspline) - b2 = basis.RaisedCosineLinearConv(**kwargs_raised_cosine) - basis_obj = b1 + b2 - elif basis_class == MultiplicativeBasis: - kwargs_mspline = inspect_utils.trim_kwargs( - basis.MSplineEval, kwargs, class_specific_params - ) - kwargs_raised_cosine = inspect_utils.trim_kwargs( - basis.RaisedCosineLinearConv, kwargs, class_specific_params - ) - b1 = basis.MSplineEval(**kwargs_mspline) - b2 = basis.RaisedCosineLinearConv(**kwargs_raised_cosine) - basis_obj = b1 * b2 - else: - basis_obj = basis_class( - **inspect_utils.trim_kwargs(basis_class, kwargs, class_specific_params) - ) - return basis_obj - - class TestAdditiveBasis(CombinedBasis): cls = {"eval": AdditiveBasis, "conv": AdditiveBasis} @pytest.mark.parametrize("samples", [[[0], []], [[], [0]], [[0, 0], [0, 0]]]) @pytest.mark.parametrize("base_cls", [basis.BSplineEval, basis.BSplineConv]) - def test_non_empty_samples(self, base_cls, samples, class_specific_params): + def test_non_empty_samples(self, base_cls, samples, basis_class_specific_params): kwargs = 
{"window_size": 2, "n_basis_funcs": 5} - kwargs = inspect_utils.trim_kwargs(base_cls, kwargs, class_specific_params) + kwargs = inspect_utils.trim_kwargs( + base_cls, kwargs, basis_class_specific_params + ) basis_obj = base_cls(**kwargs) + base_cls(**kwargs) if any(tuple(len(s) == 0 for s in samples)): with pytest.raises( @@ -2066,7 +1970,7 @@ def test_compute_features_returns_expected_number_of_basis( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the evaluation of the `AdditiveBasis` results in a number of basis @@ -2074,10 +1978,10 @@ def test_compute_features_returns_expected_number_of_basis( """ # define the two basis basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj @@ -2105,16 +2009,16 @@ def test_sample_size_of_compute_features_matches_that_of_input( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the output sample size from `AdditiveBasis` compute_features function matches input sample size. 
""" basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj eval_basis = basis_obj.compute_features( @@ -2142,17 +2046,17 @@ def test_number_of_required_inputs_compute_features( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the number of required inputs for the `compute_features` function matches the sum of the number of input samples from the two bases. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj required_dim = ( @@ -2174,16 +2078,22 @@ def test_number_of_required_inputs_compute_features( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_meshgrid_size( - self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, + sample_size, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + basis_class_specific_params, ): """ Test whether the resulting meshgrid size matches the sample size input. 
""" basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj + basis_b_obj res = basis_obj.evaluate_on_grid( @@ -2198,16 +2108,22 @@ def test_evaluate_on_grid_meshgrid_size( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_basis_size( - self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, + sample_size, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + basis_class_specific_params, ): """ Test whether the number sample size output by evaluate_on_grid matches the sample size of the input. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj + basis_b_obj eval_basis = basis_obj.evaluate_on_grid( @@ -2221,17 +2137,23 @@ def test_evaluate_on_grid_basis_size( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_input_number( - self, n_input, basis_a, basis_b, n_basis_a, n_basis_b, class_specific_params + self, + n_input, + basis_a, + basis_b, + n_basis_a, + n_basis_b, + basis_class_specific_params, ): """ Test whether the number of inputs provided to `evaluate_on_grid` matches the sum of the number of input samples required from each of the basis objects. 
""" basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj + basis_b_obj inputs = [20] * n_input @@ -2253,7 +2175,13 @@ def test_evaluate_on_grid_input_number( @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) def test_pynapple_support_compute_features( - self, basis_a, basis_b, n_basis_a, n_basis_b, sample_size, class_specific_params + self, + basis_a, + basis_b, + n_basis_a, + n_basis_b, + sample_size, + basis_class_specific_params, ): iset = nap.IntervalSet(start=[0, 0.5], end=[0.49999, 1]) inp = nap.Tsd( @@ -2262,9 +2190,9 @@ def test_pynapple_support_compute_features( time_support=iset, ) basis_add = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) + self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) # compute_features the basis over pynapple Tsd objects out = basis_add.compute_features(*([inp] * basis_add._n_input_dimensionality)) @@ -2288,13 +2216,13 @@ def test_call_input_num( basis_b, num_input, window_size, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj if num_input == 
basis_obj._n_input_dimensionality: @@ -2327,13 +2255,13 @@ def test_call_input_shape( inp, window_size, expectation, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj with expectation: @@ -2353,13 +2281,13 @@ def test_call_sample_axis( basis_b, time_axis_shape, window_size, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj inp = [np.linspace(0, 1, time_axis_shape)] * basis_obj._n_input_dimensionality @@ -2371,7 +2299,13 @@ def test_call_sample_axis( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_nan( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): if ( basis_a == basis.OrthExponentialBasis @@ -2379,10 +2313,10 @@ def test_call_nan( ): return basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, 
basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj inp = [np.linspace(0, 1, 10)] * basis_obj._n_input_dimensionality @@ -2395,21 +2329,21 @@ def test_call_nan( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_equivalent_in_conv( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=9 + n_basis_a, basis_a, basis_class_specific_params, window_size=9 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=9 + n_basis_b, basis_b, basis_class_specific_params, window_size=9 ) bas_eva = basis_a_obj + basis_b_obj basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=8 + n_basis_a, basis_a, basis_class_specific_params, window_size=8 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=8 + n_basis_b, basis_b, basis_class_specific_params, window_size=8 ) bas_con = basis_a_obj + basis_b_obj @@ -2422,13 +2356,19 @@ def test_call_equivalent_in_conv( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_pynapple_support( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj x = np.linspace(0, 1, 10) @@ -2446,13 
+2386,19 @@ def test_pynapple_support( @pytest.mark.parametrize("n_basis_a", [6, 7]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_basis_number( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality @@ -2467,13 +2413,19 @@ def test_call_basis_number( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_non_empty( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj with pytest.raises(ValueError, match="All sample provided must"): @@ -2502,7 +2454,7 @@ def test_call_sample_range( mx, expectation, window_size, - class_specific_params, + basis_class_specific_params, ): if expectation == "check": if ( @@ -2515,10 +2467,10 @@ def test_call_sample_range( else: expectation = does_not_raise() basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, 
basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj with expectation: @@ -2529,13 +2481,13 @@ def test_call_sample_range( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_fit_kernel( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) bas = basis_a_obj + basis_b_obj bas.set_kernel() @@ -2558,13 +2510,13 @@ def check_kernel(basis_obj): @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_transform_fails( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) bas = basis_a_obj + basis_b_obj if "Eval" in basis_a.__name__ and "Eval" in basis_b.__name__: @@ -2636,16 +2588,16 @@ def test_set_input_shape_type_1d_arrays( basis_b, shape_a, shape_b, - class_specific_params, + basis_class_specific_params, add_shape_a, add_shape_b, ): x = (np.ones((10, *add_shape_a)), np.ones((10, *add_shape_b))) basis_a = 
self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) add = basis_a + basis_b @@ -2675,16 +2627,16 @@ def test_set_input_shape_type_2d_arrays( basis_b, shape_a, shape_b, - class_specific_params, + basis_class_specific_params, add_shape_a, add_shape_b, ): x = (np.ones((10, 2, *add_shape_a)), np.ones((10, 3, *add_shape_b))) basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) add = basis_a + basis_b @@ -2714,16 +2666,16 @@ def test_set_input_shape_type_nd_arrays( basis_b, shape_a, shape_b, - class_specific_params, + basis_class_specific_params, add_shape_a, add_shape_b, ): x = (np.ones((10, 2, 2, *add_shape_a)), np.ones((10, 3, 1, *add_shape_b))) basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) add = basis_a + basis_b @@ -2762,13 +2714,13 @@ def test_set_input_shape_type_nd_arrays( "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") ) def test_set_input_value_types( - self, inp_shape, expectation, basis_a, basis_b, class_specific_params + self, inp_shape, expectation, basis_a, basis_b, basis_class_specific_params ): basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, 
basis_b, basis_class_specific_params, window_size=10 ) add = basis_a + basis_b with expectation: @@ -2780,12 +2732,12 @@ def test_set_input_value_types( @pytest.mark.parametrize( "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") ) - def test_deep_copy_basis(self, basis_a, basis_b, class_specific_params): + def test_deep_copy_basis(self, basis_a, basis_b, basis_class_specific_params): basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) add = basis_a + basis_b # test pointing to different objects @@ -2811,12 +2763,14 @@ def test_deep_copy_basis(self, basis_a, basis_b, class_specific_params): @pytest.mark.parametrize( "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") ) - def test_compute_n_basis_runtime(self, basis_a, basis_b, class_specific_params): + def test_compute_n_basis_runtime( + self, basis_a, basis_b, basis_class_specific_params + ): basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) add = basis_a + basis_b add.basis1.n_basis_funcs = 10 @@ -2826,15 +2780,17 @@ def test_compute_n_basis_runtime(self, basis_a, basis_b, class_specific_params): @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) - def test_runtime_n_basis_out_compute(self, basis_a, basis_b, class_specific_params): + def test_runtime_n_basis_out_compute( + self, basis_a, basis_b, basis_class_specific_params + ): basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, 
basis_a, basis_class_specific_params, window_size=10 ) basis_a.set_input_shape( *([1] * basis_a._n_input_dimensionality) ).to_transformer() basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) basis_b.set_input_shape( *([1] * basis_b._n_input_dimensionality) @@ -2904,7 +2860,7 @@ def test_compute_features_returns_expected_number_of_basis( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the evaluation of the `MultiplicativeBasis` results in a number of basis @@ -2912,10 +2868,10 @@ def test_compute_features_returns_expected_number_of_basis( """ # define the two basis basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj @@ -2944,17 +2900,17 @@ def test_sample_size_of_compute_features_matches_that_of_input( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the output sample size from the `MultiplicativeBasis` fit_transform function matches the input sample size. 
""" basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj eval_basis = basis_obj.compute_features( @@ -2981,17 +2937,17 @@ def test_number_of_required_inputs_compute_features( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the number of required inputs for the `compute_features` function matches the sum of the number of input samples from the two bases. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj required_dim = ( @@ -3013,16 +2969,22 @@ def test_number_of_required_inputs_compute_features( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_meshgrid_size( - self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, + sample_size, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + basis_class_specific_params, ): """ Test whether the resulting meshgrid size matches the sample size input. 
""" basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj * basis_b_obj res = basis_obj.evaluate_on_grid( @@ -3037,16 +2999,22 @@ def test_evaluate_on_grid_meshgrid_size( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_basis_size( - self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, + sample_size, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + basis_class_specific_params, ): """ Test whether the number sample size output by evaluate_on_grid matches the sample size of the input. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj * basis_b_obj eval_basis = basis_obj.evaluate_on_grid( @@ -3060,17 +3028,23 @@ def test_evaluate_on_grid_basis_size( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_input_number( - self, n_input, basis_a, basis_b, n_basis_a, n_basis_b, class_specific_params + self, + n_input, + basis_a, + basis_b, + n_basis_a, + n_basis_b, + basis_class_specific_params, ): """ Test whether the number of inputs provided to `evaluate_on_grid` matches the sum of the number of input samples required from each of the basis objects. 
""" basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj * basis_b_obj inputs = [20] * n_input @@ -3100,15 +3074,15 @@ def test_inconsistent_sample_sizes( n_basis_b, sample_size_a, sample_size_b, - class_specific_params, + basis_class_specific_params, ): """Test that the inputs of inconsistent sample sizes result in an exception when compute_features is called""" raise_exception = sample_size_a != sample_size_b basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) input_a = [ np.linspace(0, 1, sample_size_a) @@ -3132,7 +3106,13 @@ def test_inconsistent_sample_sizes( @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) def test_pynapple_support_compute_features( - self, basis_a, basis_b, n_basis_a, n_basis_b, sample_size, class_specific_params + self, + basis_a, + basis_b, + n_basis_a, + n_basis_b, + sample_size, + basis_class_specific_params, ): iset = nap.IntervalSet(start=[0, 0.5], end=[0.49999, 1]) inp = nap.Tsd( @@ -3141,9 +3121,9 @@ def test_pynapple_support_compute_features( time_support=iset, ) basis_prod = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) * self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) out = 
basis_prod.compute_features(*([inp] * basis_prod._n_input_dimensionality)) assert isinstance(out, nap.TsdFrame) @@ -3164,13 +3144,13 @@ def test_call_input_num( basis_b, num_input, window_size, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj if num_input == basis_obj._n_input_dimensionality: @@ -3203,13 +3183,13 @@ def test_call_input_shape( inp, window_size, expectation, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj with expectation: @@ -3229,13 +3209,13 @@ def test_call_sample_axis( basis_b, time_axis_shape, window_size, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj inp = [np.linspace(0, 1, time_axis_shape)] * basis_obj._n_input_dimensionality @@ -3247,7 +3227,13 @@ def test_call_sample_axis( @pytest.mark.parametrize("n_basis_a", 
[5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_nan( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): if ( basis_a == basis.OrthExponentialBasis @@ -3255,10 +3241,10 @@ def test_call_nan( ): return basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj inp = [np.linspace(0, 1, 10)] * basis_obj._n_input_dimensionality @@ -3271,21 +3257,21 @@ def test_call_nan( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_equivalent_in_conv( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) bas_eva = basis_a_obj * basis_b_obj basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=8 + n_basis_a, basis_a, basis_class_specific_params, window_size=8 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=8 + n_basis_b, basis_b, basis_class_specific_params, window_size=8 ) bas_con = basis_a_obj * basis_b_obj @@ -3298,13 +3284,19 @@ def test_call_equivalent_in_conv( @pytest.mark.parametrize("n_basis_a", [5]) 
@pytest.mark.parametrize("n_basis_b", [5]) def test_pynapple_support( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj x = np.linspace(0, 1, 10) @@ -3322,13 +3314,19 @@ def test_pynapple_support( @pytest.mark.parametrize("n_basis_a", [6, 7]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_basis_number( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality @@ -3343,13 +3341,19 @@ def test_call_basis_number( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_non_empty( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, 
basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj with pytest.raises(ValueError, match="All sample provided must"): @@ -3378,7 +3382,7 @@ def test_call_sample_range( mx, expectation, window_size, - class_specific_params, + basis_class_specific_params, ): if expectation == "check": if ( @@ -3391,10 +3395,10 @@ def test_call_sample_range( else: expectation = does_not_raise() basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj with expectation: @@ -3405,13 +3409,13 @@ def test_call_sample_range( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_fit_kernel( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) bas = basis_a_obj * basis_b_obj bas.set_kernel() @@ -3434,13 +3438,13 @@ def check_kernel(basis_obj): @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_transform_fails( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, 
basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) bas = basis_a_obj * basis_b_obj if "Eval" in basis_a.__name__ and "Eval" in basis_b.__name__: @@ -3523,16 +3527,16 @@ def test_set_input_shape_type_1d_arrays( basis_b, shape_a, shape_b, - class_specific_params, + basis_class_specific_params, add_shape_a, add_shape_b, ): x = (np.ones((10, *add_shape_a)), np.ones((10, *add_shape_b))) basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) mul = basis_a * basis_b @@ -3562,16 +3566,16 @@ def test_set_input_shape_type_2d_arrays( basis_b, shape_a, shape_b, - class_specific_params, + basis_class_specific_params, add_shape_a, add_shape_b, ): x = (np.ones((10, 2, *add_shape_a)), np.ones((10, 3, *add_shape_b))) basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) mul = basis_a * basis_b @@ -3601,16 +3605,16 @@ def test_set_input_shape_type_nd_arrays( basis_b, shape_a, shape_b, - class_specific_params, + basis_class_specific_params, add_shape_a, add_shape_b, ): x = (np.ones((10, 2, 2, *add_shape_a)), np.ones((10, 3, 1, *add_shape_b))) basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = 
self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) mul = basis_a * basis_b @@ -3649,13 +3653,13 @@ def test_set_input_shape_type_nd_arrays( "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") ) def test_set_input_value_types( - self, inp_shape, expectation, basis_a, basis_b, class_specific_params + self, inp_shape, expectation, basis_a, basis_b, basis_class_specific_params ): basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) mul = basis_a * basis_b with expectation: @@ -3667,12 +3671,12 @@ def test_set_input_value_types( @pytest.mark.parametrize( "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") ) - def test_deep_copy_basis(self, basis_a, basis_b, class_specific_params): + def test_deep_copy_basis(self, basis_a, basis_b, basis_class_specific_params): basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) mul = basis_a * basis_b # test pointing to different objects @@ -3698,12 +3702,14 @@ def test_deep_copy_basis(self, basis_a, basis_b, class_specific_params): @pytest.mark.parametrize( "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") ) - def test_compute_n_basis_runtime(self, basis_a, basis_b, class_specific_params): + def test_compute_n_basis_runtime( + self, basis_a, basis_b, basis_class_specific_params + ): basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, 
window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) mul = basis_a * basis_b mul.basis1.n_basis_funcs = 10 @@ -3713,15 +3719,17 @@ def test_compute_n_basis_runtime(self, basis_a, basis_b, class_specific_params): @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) - def test_runtime_n_basis_out_compute(self, basis_a, basis_b, class_specific_params): + def test_runtime_n_basis_out_compute( + self, basis_a, basis_b, basis_class_specific_params + ): basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_a.set_input_shape( *([1] * basis_a._n_input_dimensionality) ).to_transformer() basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) basis_b.set_input_shape( *([1] * basis_b._n_input_dimensionality) @@ -3747,7 +3755,7 @@ def test_runtime_n_basis_out_compute(self, basis_a, basis_b, class_specific_para "exponent", [-1, 0, 0.5, basis.RaisedCosineLogEval(4), 1, 2, 3] ) @pytest.mark.parametrize("basis_class", list_all_basis_classes()) -def test_power_of_basis(exponent, basis_class, class_specific_params): +def test_power_of_basis(exponent, basis_class, basis_class_specific_params): """Test if the power behaves as expected.""" raise_exception_type = not type(exponent) is int @@ -3757,7 +3765,7 @@ def test_power_of_basis(exponent, basis_class, class_specific_params): raise_exception_value = False basis_obj = CombinedBasis.instantiate_basis( - 5, basis_class, class_specific_params, window_size=10 + 5, basis_class, basis_class_specific_params, window_size=10 ) if raise_exception_type: @@ -3793,10 +3801,10 @@ def test_power_of_basis(exponent, basis_class, class_specific_params): "basis_cls", 
list_all_basis_classes(), ) -def test_basis_to_transformer(basis_cls, class_specific_params): +def test_basis_to_transformer(basis_cls, basis_class_specific_params): n_basis_funcs = 5 bas = CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=10 + n_basis_funcs, basis_cls, basis_class_specific_params, window_size=10 ) trans_bas = bas.set_input_shape( *([1] * bas._n_input_dimensionality) @@ -3812,413 +3820,6 @@ def test_basis_to_transformer(basis_cls, class_specific_params): assert np.all(getattr(bas, k) == getattr(trans_bas, k)) -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_transformer_has_the_same_public_attributes_as_basis( - basis_cls, class_specific_params -): - n_basis_funcs = 5 - bas = CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=10 - ) - - public_attrs_basis = {attr for attr in dir(bas) if not attr.startswith("_")} - public_attrs_transformerbasis = { - attr - for attr in dir( - bas.set_input_shape(*([1] * bas._n_input_dimensionality)).to_transformer() - ) - if not attr.startswith("_") - } - - assert public_attrs_transformerbasis - public_attrs_basis == { - "fit", - "fit_transform", - "transform", - "basis", - } - - assert public_attrs_basis - public_attrs_transformerbasis == set() - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), -) -def test_to_transformer_and_constructor_are_equivalent( - basis_cls, class_specific_params -): - n_basis_funcs = 5 - bas = CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=10 - ) - bas.set_input_shape(*([1] * bas._n_input_dimensionality)) - trans_bas_a = bas.to_transformer() - trans_bas_b = basis.TransformerBasis(bas) - - # they both just have a _basis - assert ( - list(trans_bas_a.__dict__.keys()) - == list(trans_bas_b.__dict__.keys()) - == ["_basis"] - ) - # and those bases are the 
same - assert np.all( - trans_bas_a._basis.__dict__.pop("_decay_rates", 1) - == trans_bas_b._basis.__dict__.pop("_decay_rates", 1) - ) - assert trans_bas_a._basis.__dict__ == trans_bas_b._basis.__dict__ - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_basis_to_transformer_makes_a_copy(basis_cls, class_specific_params): - bas_a = CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 - ) - trans_bas_a = bas_a.set_input_shape( - *([1] * bas_a._n_input_dimensionality) - ).to_transformer() - - # changing an attribute in bas should not change trans_bas - if basis_cls in [AdditiveBasis, MultiplicativeBasis]: - bas_a._basis1.n_basis_funcs = 10 - assert trans_bas_a._basis._basis1.n_basis_funcs == 5 - - # changing an attribute in the transformer basis should not change the original - bas_b = CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 - ) - bas_b.set_input_shape(*([1] * bas_b._n_input_dimensionality)) - trans_bas_b = bas_b.to_transformer() - trans_bas_b._basis._basis1.n_basis_funcs = 100 - assert bas_b._basis1.n_basis_funcs == 5 - else: - bas_a.n_basis_funcs = 10 - assert trans_bas_a.n_basis_funcs == 5 - - # changing an attribute in the transformer basis should not change the original - bas_b = CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 - ) - trans_bas_b = bas_b.set_input_shape( - *([1] * bas_b._n_input_dimensionality) - ).to_transformer() - trans_bas_b.n_basis_funcs = 100 - assert bas_b.n_basis_funcs == 5 - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -@pytest.mark.parametrize("n_basis_funcs", [5, 10, 20]) -def test_transformerbasis_getattr(basis_cls, n_basis_funcs, class_specific_params): - bas = CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=30 - ) - trans_basis = basis.TransformerBasis( - bas.set_input_shape(*([1] * 
bas._n_input_dimensionality)) - ) - if basis_cls in [AdditiveBasis, MultiplicativeBasis]: - for bas in [ - getattr(trans_basis._basis, attr) for attr in ("_basis1", "_basis2") - ]: - assert bas.n_basis_funcs == n_basis_funcs - else: - assert trans_basis.n_basis_funcs == n_basis_funcs - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), -) -@pytest.mark.parametrize("n_basis_funcs_init", [5]) -@pytest.mark.parametrize("n_basis_funcs_new", [6, 10, 20]) -def test_transformerbasis_set_params( - basis_cls, n_basis_funcs_init, n_basis_funcs_new, class_specific_params -): - bas = CombinedBasis().instantiate_basis( - n_basis_funcs_init, basis_cls, class_specific_params, window_size=10 - ) - trans_basis = basis.TransformerBasis( - bas.set_input_shape(*([1] * bas._n_input_dimensionality)) - ) - trans_basis.set_params(n_basis_funcs=n_basis_funcs_new) - - assert trans_basis.n_basis_funcs == n_basis_funcs_new - assert trans_basis._basis.n_basis_funcs == n_basis_funcs_new - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), -) -def test_transformerbasis_setattr_basis(basis_cls, class_specific_params): - # setting the _basis attribute should change it - bas = CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=30 - ) - trans_bas = basis.TransformerBasis( - bas.set_input_shape(*([1] * bas._n_input_dimensionality)) - ) - - bas = CombinedBasis().instantiate_basis( - 20, basis_cls, class_specific_params, window_size=30 - ) - - trans_bas.basis = bas.set_input_shape(*([1] * bas._n_input_dimensionality)) - - assert trans_bas.n_basis_funcs == 20 - assert trans_bas._basis.n_basis_funcs == 20 - assert isinstance(trans_bas._basis, basis_cls) - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), -) -def test_transformerbasis_setattr_basis_attribute(basis_cls, class_specific_params): 
- # setting an attribute that is an attribute of the underlying _basis - # should propagate setting it on _basis itself - bas = CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=10 - ) - trans_bas = basis.TransformerBasis( - bas.set_input_shape(*([1] * bas._n_input_dimensionality)) - ) - trans_bas.n_basis_funcs = 20 - - assert trans_bas.n_basis_funcs == 20 - assert trans_bas._basis.n_basis_funcs == 20 - assert isinstance(trans_bas._basis, basis_cls) - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), -) -def test_transformerbasis_copy_basis_on_contsruct(basis_cls, class_specific_params): - # modifying the transformerbasis's attributes shouldn't - # touch the original basis that was used to create it - orig_bas = CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=10 - ) - orig_bas = orig_bas.set_input_shape(*([1] * orig_bas._n_input_dimensionality)) - trans_bas = basis.TransformerBasis(orig_bas) - trans_bas.n_basis_funcs = 20 - - assert orig_bas.n_basis_funcs == 10 - assert trans_bas._basis.n_basis_funcs == 20 - assert trans_bas._basis.n_basis_funcs == 20 - assert isinstance(trans_bas._basis, basis_cls) - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_transformerbasis_setattr_illegal_attribute(basis_cls, class_specific_params): - # changing an attribute that is not _basis or an attribute of _basis - # is not allowed - bas = CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=10 - ) - trans_bas = basis.TransformerBasis( - bas.set_input_shape(*([1] * bas._n_input_dimensionality)) - ) - - with pytest.raises( - ValueError, - match="Only setting _basis or existing attributes of _basis is allowed.", - ): - trans_bas.random_attr = "random value" - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_transformerbasis_addition(basis_cls, 
class_specific_params): - n_basis_funcs_a = 5 - n_basis_funcs_b = n_basis_funcs_a * 2 - bas_a = CombinedBasis().instantiate_basis( - n_basis_funcs_a, basis_cls, class_specific_params, window_size=10 - ) - bas_a.set_input_shape(*([1] * bas_a._n_input_dimensionality)) - bas_b = CombinedBasis().instantiate_basis( - n_basis_funcs_b, basis_cls, class_specific_params, window_size=10 - ) - bas_b.set_input_shape(*([1] * bas_b._n_input_dimensionality)) - trans_bas_a = basis.TransformerBasis(bas_a) - trans_bas_b = basis.TransformerBasis(bas_b) - trans_bas_sum = trans_bas_a + trans_bas_b - assert isinstance(trans_bas_sum, basis.TransformerBasis) - assert isinstance(trans_bas_sum._basis, AdditiveBasis) - assert ( - trans_bas_sum.n_basis_funcs - == trans_bas_a.n_basis_funcs + trans_bas_b.n_basis_funcs - ) - assert ( - trans_bas_sum._n_input_dimensionality - == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality - ) - if basis_cls not in [AdditiveBasis, MultiplicativeBasis]: - assert trans_bas_sum._basis1.n_basis_funcs == n_basis_funcs_a - assert trans_bas_sum._basis2.n_basis_funcs == n_basis_funcs_b - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_transformerbasis_multiplication(basis_cls, class_specific_params): - n_basis_funcs_a = 5 - n_basis_funcs_b = n_basis_funcs_a * 2 - bas1 = CombinedBasis().instantiate_basis( - n_basis_funcs_a, basis_cls, class_specific_params, window_size=10 - ) - trans_bas_a = basis.TransformerBasis( - bas1.set_input_shape(*([1] * bas1._n_input_dimensionality)) - ) - bas2 = CombinedBasis().instantiate_basis( - n_basis_funcs_b, basis_cls, class_specific_params, window_size=10 - ) - trans_bas_b = basis.TransformerBasis( - bas2.set_input_shape(*([1] * bas2._n_input_dimensionality)) - ) - trans_bas_prod = trans_bas_a * trans_bas_b - assert isinstance(trans_bas_prod, basis.TransformerBasis) - assert isinstance(trans_bas_prod._basis, MultiplicativeBasis) - assert ( - trans_bas_prod.n_basis_funcs - 
== trans_bas_a.n_basis_funcs * trans_bas_b.n_basis_funcs - ) - assert ( - trans_bas_prod._n_input_dimensionality - == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality - ) - if basis_cls not in [AdditiveBasis, MultiplicativeBasis]: - assert trans_bas_prod._basis1.n_basis_funcs == n_basis_funcs_a - assert trans_bas_prod._basis2.n_basis_funcs == n_basis_funcs_b - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -@pytest.mark.parametrize( - "exponent, error_type, error_message", - [ - (2, does_not_raise, None), - (5, does_not_raise, None), - (0.5, TypeError, "Exponent should be an integer"), - (-1, ValueError, "Exponent should be a non-negative integer"), - ], -) -def test_transformerbasis_exponentiation( - basis_cls, exponent: int, error_type, error_message, class_specific_params -): - bas = CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 - ) - trans_bas = basis.TransformerBasis( - bas.set_input_shape(*([1] * bas._n_input_dimensionality)) - ) - - if not isinstance(exponent, int): - with pytest.raises(error_type, match=error_message): - trans_bas_exp = trans_bas**exponent - assert isinstance(trans_bas_exp, basis.TransformerBasis) - assert isinstance(trans_bas_exp._basis, MultiplicativeBasis) - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_transformerbasis_dir(basis_cls, class_specific_params): - bas = CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 - ) - trans_bas = basis.TransformerBasis( - bas.set_input_shape(*([1] * bas._n_input_dimensionality)) - ) - for attr_name in ( - "fit", - "transform", - "fit_transform", - "n_basis_funcs", - "mode", - "window_size", - ): - if ( - attr_name == "window_size" - and "Conv" not in trans_bas._basis.__class__.__name__ - ): - continue - assert attr_name in dir(trans_bas) - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv"), -) -def 
test_transformerbasis_sk_clone_kernel_noned(basis_cls, class_specific_params): - orig_bas = CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=20 - ) - orig_bas.set_input_shape(*([1] * orig_bas._n_input_dimensionality)) - trans_bas = basis.TransformerBasis(orig_bas) - - # kernel should be saved in the object after fit - trans_bas.fit(np.random.randn(100, 20)) - assert isinstance(trans_bas.kernel_, np.ndarray) - - # cloning should set kernel_ to None - trans_bas_clone = sk_clone(trans_bas) - - # the original object should still have kernel_ - assert isinstance(trans_bas.kernel_, np.ndarray) - # but the clone should not have one - assert trans_bas_clone.kernel_ is None - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -@pytest.mark.parametrize("n_basis_funcs", [5]) -def test_transformerbasis_pickle( - tmpdir, basis_cls, n_basis_funcs, class_specific_params -): - bas = CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=10 - ) - # the test that tries cross-validation with n_jobs = 2 already should test this - trans_bas = basis.TransformerBasis( - bas.set_input_shape(*([1] * bas._n_input_dimensionality)) - ) - filepath = tmpdir / "transformerbasis.pickle" - with open(filepath, "wb") as f: - pickle.dump(trans_bas, f) - with open(filepath, "rb") as f: - trans_bas2 = pickle.load(f) - - assert isinstance(trans_bas2, basis.TransformerBasis) - if basis_cls in [AdditiveBasis, MultiplicativeBasis]: - for bas in [ - getattr(trans_bas2._basis, attr) for attr in ("_basis1", "_basis2") - ]: - assert bas.n_basis_funcs == n_basis_funcs - else: - assert trans_bas2.n_basis_funcs == n_basis_funcs - - @pytest.mark.parametrize( "tsd", [ @@ -4257,7 +3858,7 @@ def test_multi_epoch_pynapple_basis( shift, predictor_causality, nan_index, - class_specific_params, + basis_class_specific_params, ): """Test nan location in multi-epoch pynapple tsd.""" kwargs = dict( @@ -4271,7 +3872,11 @@ 
def test_multi_epoch_pynapple_basis( else: nbasis = 5 bas = CombinedBasis().instantiate_basis( - nbasis, basis_cls, class_specific_params, window_size=window_size, **kwargs + nbasis, + basis_cls, + basis_class_specific_params, + window_size=window_size, + **kwargs, ) n_input = bas._n_input_dimensionality @@ -4324,7 +3929,7 @@ def test_multi_epoch_pynapple_basis_transformer( shift, predictor_causality, nan_index, - class_specific_params, + basis_class_specific_params, ): """Test nan location in multi-epoch pynapple tsd.""" kwargs = dict( @@ -4338,7 +3943,11 @@ def test_multi_epoch_pynapple_basis_transformer( nbasis = 5 bas = CombinedBasis().instantiate_basis( - nbasis, basis_cls, class_specific_params, window_size=window_size, **kwargs + nbasis, + basis_cls, + basis_class_specific_params, + window_size=window_size, + **kwargs, ) n_input = bas._n_input_dimensionality @@ -4436,7 +4045,7 @@ def test_multi_epoch_pynapple_basis_transformer( ], ) def test__get_splitter( - bas1, bas2, bas3, operator1, operator2, compute_slice, class_specific_params + bas1, bas2, bas3, operator1, operator2, compute_slice, basis_class_specific_params ): # skip nested if any( @@ -4450,19 +4059,19 @@ def test__get_splitter( combine_basis = CombinedBasis() bas1_instance = combine_basis.instantiate_basis( - n_basis[0], bas1, class_specific_params, window_size=10, label="1" + n_basis[0], bas1, basis_class_specific_params, window_size=10, label="1" ) bas1_instance.set_input_shape( *([n_input_basis[0]] * bas1_instance._n_input_dimensionality) ) bas2_instance = combine_basis.instantiate_basis( - n_basis[1], bas2, class_specific_params, window_size=10, label="2" + n_basis[1], bas2, basis_class_specific_params, window_size=10, label="2" ) bas2_instance.set_input_shape( *([n_input_basis[1]] * bas2_instance._n_input_dimensionality) ) bas3_instance = combine_basis.instantiate_basis( - n_basis[2], bas3, class_specific_params, window_size=10, label="3" + n_basis[2], bas3, basis_class_specific_params, 
window_size=10, label="3" ) bas3_instance.set_input_shape( *([n_input_basis[2]] * bas3_instance._n_input_dimensionality) @@ -4602,7 +4211,7 @@ def test__get_splitter_split_by_input( n_input_basis_1, n_input_basis_2, compute_slice, - class_specific_params, + basis_class_specific_params, ): # skip nested if any( @@ -4614,14 +4223,14 @@ def test__get_splitter_split_by_input( n_basis = [5, 6] combine_basis = CombinedBasis() bas1_instance = combine_basis.instantiate_basis( - n_basis[0], bas1, class_specific_params, window_size=10, label="1" + n_basis[0], bas1, basis_class_specific_params, window_size=10, label="1" ) bas1_instance.set_input_shape( *([n_input_basis_1] * bas1_instance._n_input_dimensionality) ) bas2_instance = combine_basis.instantiate_basis( - n_basis[1], bas2, class_specific_params, window_size=10, label="2" + n_basis[1], bas2, basis_class_specific_params, window_size=10, label="2" ) bas2_instance.set_input_shape( *([n_input_basis_2] * bas1_instance._n_input_dimensionality) @@ -4644,7 +4253,7 @@ def test__get_splitter_split_by_input( "bas1, bas2, bas3", list(itertools.product(*[list_all_basis_classes()] * 3)), ) -def test_duplicate_keys(bas1, bas2, bas3, class_specific_params): +def test_duplicate_keys(bas1, bas2, bas3, basis_class_specific_params): # skip nested if any( bas in (AdditiveBasis, MultiplicativeBasis, basis.TransformerBasis) @@ -4654,13 +4263,13 @@ def test_duplicate_keys(bas1, bas2, bas3, class_specific_params): combine_basis = CombinedBasis() bas1_instance = combine_basis.instantiate_basis( - 5, bas1, class_specific_params, window_size=10, label="label" + 5, bas1, basis_class_specific_params, window_size=10, label="label" ) bas2_instance = combine_basis.instantiate_basis( - 5, bas2, class_specific_params, window_size=10, label="label" + 5, bas2, basis_class_specific_params, window_size=10, label="label" ) bas3_instance = combine_basis.instantiate_basis( - 5, bas3, class_specific_params, window_size=10, label="label" + 5, bas3, 
basis_class_specific_params, window_size=10, label="label" ) bas_obj = bas1_instance + bas2_instance + bas3_instance @@ -4691,7 +4300,7 @@ def test_duplicate_keys(bas1, bas2, bas3, class_specific_params): ], ) def test_split_feature_axis( - bas1, bas2, x, axis, expectation, exp_shapes, class_specific_params + bas1, bas2, x, axis, expectation, exp_shapes, basis_class_specific_params ): # skip nested if any( @@ -4703,10 +4312,10 @@ def test_split_feature_axis( n_basis = [5, 6] combine_basis = CombinedBasis() bas1_instance = combine_basis.instantiate_basis( - n_basis[0], bas1, class_specific_params, window_size=10, label="1" + n_basis[0], bas1, basis_class_specific_params, window_size=10, label="1" ) bas2_instance = combine_basis.instantiate_basis( - n_basis[1], bas2, class_specific_params, window_size=10, label="2" + n_basis[1], bas2, basis_class_specific_params, window_size=10, label="2" ) bas = bas1_instance + bas2_instance diff --git a/tests/test_transformer_basis.py b/tests/test_transformer_basis.py new file mode 100644 index 00000000..aadd9174 --- /dev/null +++ b/tests/test_transformer_basis.py @@ -0,0 +1,660 @@ +import pickle +from contextlib import nullcontext as does_not_raise + +import numpy as np +import pytest +from conftest import CombinedBasis, list_all_basis_classes +from sklearn.base import clone as sk_clone + +from nemos import basis + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_transformer_has_the_same_public_attributes_as_basis( + basis_cls, basis_class_specific_params +): + n_basis_funcs = 5 + bas = CombinedBasis().instantiate_basis( + n_basis_funcs, basis_cls, basis_class_specific_params, window_size=10 + ) + + public_attrs_basis = {attr for attr in dir(bas) if not attr.startswith("_")} + public_attrs_transformerbasis = { + attr + for attr in dir( + bas.set_input_shape(*([1] * bas._n_input_dimensionality)).to_transformer() + ) + if not attr.startswith("_") + } + + assert public_attrs_transformerbasis - 
public_attrs_basis == { + "fit", + "fit_transform", + "transform", + "basis", + } + + assert public_attrs_basis - public_attrs_transformerbasis == set() + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), +) +def test_to_transformer_and_constructor_are_equivalent( + basis_cls, basis_class_specific_params +): + n_basis_funcs = 5 + bas = CombinedBasis().instantiate_basis( + n_basis_funcs, basis_cls, basis_class_specific_params, window_size=10 + ) + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) + trans_bas_a = bas.to_transformer() + trans_bas_b = basis.TransformerBasis(bas) + + # they both just have a _basis + assert ( + list(trans_bas_a.__dict__.keys()) + == list(trans_bas_b.__dict__.keys()) + == ["_basis"] + ) + # and those bases are the same + assert np.all( + trans_bas_a._basis.__dict__.pop("_decay_rates", 1) + == trans_bas_b._basis.__dict__.pop("_decay_rates", 1) + ) + assert trans_bas_a._basis.__dict__ == trans_bas_b._basis.__dict__ + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_basis_to_transformer_makes_a_copy(basis_cls, basis_class_specific_params): + bas_a = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + trans_bas_a = bas_a.set_input_shape( + *([1] * bas_a._n_input_dimensionality) + ).to_transformer() + + # changing an attribute in bas should not change trans_bas + if basis_cls in [basis.AdditiveBasis, basis.MultiplicativeBasis]: + bas_a._basis1.n_basis_funcs = 10 + assert trans_bas_a._basis._basis1.n_basis_funcs == 5 + + # changing an attribute in the transformer basis should not change the original + bas_b = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + bas_b.set_input_shape(*([1] * bas_b._n_input_dimensionality)) + trans_bas_b = bas_b.to_transformer() + trans_bas_b._basis._basis1.n_basis_funcs = 100 + assert bas_b._basis1.n_basis_funcs 
== 5 + else: + bas_a.n_basis_funcs = 10 + assert trans_bas_a.n_basis_funcs == 5 + + # changing an attribute in the transformer basis should not change the original + bas_b = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + trans_bas_b = bas_b.set_input_shape( + *([1] * bas_b._n_input_dimensionality) + ).to_transformer() + trans_bas_b.n_basis_funcs = 100 + assert bas_b.n_basis_funcs == 5 + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +@pytest.mark.parametrize("n_basis_funcs", [5, 10, 20]) +def test_transformerbasis_getattr( + basis_cls, n_basis_funcs, basis_class_specific_params +): + bas = CombinedBasis().instantiate_basis( + n_basis_funcs, basis_cls, basis_class_specific_params, window_size=30 + ) + trans_basis = basis.TransformerBasis( + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) + ) + if basis_cls in [basis.AdditiveBasis, basis.MultiplicativeBasis]: + for bas in [ + getattr(trans_basis._basis, attr) for attr in ("_basis1", "_basis2") + ]: + assert bas.n_basis_funcs == n_basis_funcs + else: + assert trans_basis.n_basis_funcs == n_basis_funcs + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), +) +@pytest.mark.parametrize("n_basis_funcs_init", [5]) +@pytest.mark.parametrize("n_basis_funcs_new", [6, 10, 20]) +def test_transformerbasis_set_params( + basis_cls, n_basis_funcs_init, n_basis_funcs_new, basis_class_specific_params +): + bas = CombinedBasis().instantiate_basis( + n_basis_funcs_init, basis_cls, basis_class_specific_params, window_size=10 + ) + trans_basis = basis.TransformerBasis( + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) + ) + trans_basis.set_params(n_basis_funcs=n_basis_funcs_new) + + assert trans_basis.n_basis_funcs == n_basis_funcs_new + assert trans_basis._basis.n_basis_funcs == n_basis_funcs_new + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes("Conv") + 
list_all_basis_classes("Eval"), +) +def test_transformerbasis_setattr_basis(basis_cls, basis_class_specific_params): + # setting the _basis attribute should change it + bas = CombinedBasis().instantiate_basis( + 10, basis_cls, basis_class_specific_params, window_size=30 + ) + trans_bas = basis.TransformerBasis( + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) + ) + + bas = CombinedBasis().instantiate_basis( + 20, basis_cls, basis_class_specific_params, window_size=30 + ) + + trans_bas.basis = bas.set_input_shape(*([1] * bas._n_input_dimensionality)) + + assert trans_bas.n_basis_funcs == 20 + assert trans_bas._basis.n_basis_funcs == 20 + assert isinstance(trans_bas._basis, basis_cls) + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), +) +def test_transformerbasis_setattr_basis_attribute( + basis_cls, basis_class_specific_params +): + # setting an attribute that is an attribute of the underlying _basis + # should propagate setting it on _basis itself + bas = CombinedBasis().instantiate_basis( + 10, basis_cls, basis_class_specific_params, window_size=10 + ) + trans_bas = basis.TransformerBasis( + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) + ) + trans_bas.n_basis_funcs = 20 + + assert trans_bas.n_basis_funcs == 20 + assert trans_bas._basis.n_basis_funcs == 20 + assert isinstance(trans_bas._basis, basis_cls) + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), +) +def test_transformerbasis_copy_basis_on_contsruct( + basis_cls, basis_class_specific_params +): + # modifying the transformerbasis's attributes shouldn't + # touch the original basis that was used to create it + orig_bas = CombinedBasis().instantiate_basis( + 10, basis_cls, basis_class_specific_params, window_size=10 + ) + orig_bas = orig_bas.set_input_shape(*([1] * orig_bas._n_input_dimensionality)) + trans_bas = basis.TransformerBasis(orig_bas) + 
trans_bas.n_basis_funcs = 20 + + assert orig_bas.n_basis_funcs == 10 + assert trans_bas._basis.n_basis_funcs == 20 + assert trans_bas._basis.n_basis_funcs == 20 + assert isinstance(trans_bas._basis, basis_cls) + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_transformerbasis_setattr_illegal_attribute( + basis_cls, basis_class_specific_params +): + # changing an attribute that is not _basis or an attribute of _basis + # is not allowed + bas = CombinedBasis().instantiate_basis( + 10, basis_cls, basis_class_specific_params, window_size=10 + ) + trans_bas = basis.TransformerBasis( + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) + ) + + with pytest.raises( + ValueError, + match="Only setting _basis or existing attributes of _basis is allowed.", + ): + trans_bas.random_attr = "random value" + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_transformerbasis_addition(basis_cls, basis_class_specific_params): + n_basis_funcs_a = 5 + n_basis_funcs_b = n_basis_funcs_a * 2 + bas_a = CombinedBasis().instantiate_basis( + n_basis_funcs_a, basis_cls, basis_class_specific_params, window_size=10 + ) + bas_a.set_input_shape(*([1] * bas_a._n_input_dimensionality)) + bas_b = CombinedBasis().instantiate_basis( + n_basis_funcs_b, basis_cls, basis_class_specific_params, window_size=10 + ) + bas_b.set_input_shape(*([1] * bas_b._n_input_dimensionality)) + trans_bas_a = basis.TransformerBasis(bas_a) + trans_bas_b = basis.TransformerBasis(bas_b) + trans_bas_sum = trans_bas_a + trans_bas_b + assert isinstance(trans_bas_sum, basis.TransformerBasis) + assert isinstance(trans_bas_sum._basis, basis.AdditiveBasis) + assert ( + trans_bas_sum.n_basis_funcs + == trans_bas_a.n_basis_funcs + trans_bas_b.n_basis_funcs + ) + assert ( + trans_bas_sum._n_input_dimensionality + == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality + ) + if basis_cls not in [basis.AdditiveBasis, 
basis.MultiplicativeBasis]: + assert trans_bas_sum._basis1.n_basis_funcs == n_basis_funcs_a + assert trans_bas_sum._basis2.n_basis_funcs == n_basis_funcs_b + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_transformerbasis_multiplication(basis_cls, basis_class_specific_params): + n_basis_funcs_a = 5 + n_basis_funcs_b = n_basis_funcs_a * 2 + bas1 = CombinedBasis().instantiate_basis( + n_basis_funcs_a, basis_cls, basis_class_specific_params, window_size=10 + ) + trans_bas_a = basis.TransformerBasis( + bas1.set_input_shape(*([1] * bas1._n_input_dimensionality)) + ) + bas2 = CombinedBasis().instantiate_basis( + n_basis_funcs_b, basis_cls, basis_class_specific_params, window_size=10 + ) + trans_bas_b = basis.TransformerBasis( + bas2.set_input_shape(*([1] * bas2._n_input_dimensionality)) + ) + trans_bas_prod = trans_bas_a * trans_bas_b + assert isinstance(trans_bas_prod, basis.TransformerBasis) + assert isinstance(trans_bas_prod._basis, basis.MultiplicativeBasis) + assert ( + trans_bas_prod.n_basis_funcs + == trans_bas_a.n_basis_funcs * trans_bas_b.n_basis_funcs + ) + assert ( + trans_bas_prod._n_input_dimensionality + == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality + ) + if basis_cls not in [basis.AdditiveBasis, basis.MultiplicativeBasis]: + assert trans_bas_prod._basis1.n_basis_funcs == n_basis_funcs_a + assert trans_bas_prod._basis2.n_basis_funcs == n_basis_funcs_b + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +@pytest.mark.parametrize( + "exponent, error_type, error_message", + [ + (2, does_not_raise, None), + (5, does_not_raise, None), + (0.5, TypeError, "Exponent should be an integer"), + (-1, ValueError, "Exponent should be a non-negative integer"), + ], +) +def test_transformerbasis_exponentiation( + basis_cls, exponent: int, error_type, error_message, basis_class_specific_params +): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, 
basis_class_specific_params, window_size=10 + ) + trans_bas = basis.TransformerBasis( + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) + ) + + if not isinstance(exponent, int): + with pytest.raises(error_type, match=error_message): + trans_bas_exp = trans_bas**exponent + assert isinstance(trans_bas_exp, basis.TransformerBasis) + assert isinstance(trans_bas_exp._basis, basis.MultiplicativeBasis) + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_transformerbasis_dir(basis_cls, basis_class_specific_params): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + trans_bas = basis.TransformerBasis( + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) + ) + for attr_name in ( + "fit", + "transform", + "fit_transform", + "n_basis_funcs", + "mode", + "window_size", + ): + if ( + attr_name == "window_size" + and "Conv" not in trans_bas._basis.__class__.__name__ + ): + continue + assert attr_name in dir(trans_bas) + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes("Conv"), +) +def test_transformerbasis_sk_clone_kernel_noned(basis_cls, basis_class_specific_params): + orig_bas = CombinedBasis().instantiate_basis( + 10, basis_cls, basis_class_specific_params, window_size=20 + ) + orig_bas.set_input_shape(*([1] * orig_bas._n_input_dimensionality)) + trans_bas = basis.TransformerBasis(orig_bas) + + # kernel should be saved in the object after fit + trans_bas.fit(np.random.randn(100, 20)) + assert isinstance(trans_bas.kernel_, np.ndarray) + + # cloning should set kernel_ to None + trans_bas_clone = sk_clone(trans_bas) + + # the original object should still have kernel_ + assert isinstance(trans_bas.kernel_, np.ndarray) + # but the clone should not have one + assert trans_bas_clone.kernel_ is None + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +@pytest.mark.parametrize("n_basis_funcs", [5]) +def test_transformerbasis_pickle( 
+ tmpdir, basis_cls, n_basis_funcs, basis_class_specific_params +): + bas = CombinedBasis().instantiate_basis( + n_basis_funcs, basis_cls, basis_class_specific_params, window_size=10 + ) + # the test that tries cross-validation with n_jobs = 2 already should test this + trans_bas = basis.TransformerBasis( + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) + ) + filepath = tmpdir / "transformerbasis.pickle" + with open(filepath, "wb") as f: + pickle.dump(trans_bas, f) + with open(filepath, "rb") as f: + trans_bas2 = pickle.load(f) + + assert isinstance(trans_bas2, basis.TransformerBasis) + if basis_cls in [basis.AdditiveBasis, basis.MultiplicativeBasis]: + for bas in [ + getattr(trans_bas2._basis, attr) for attr in ("_basis1", "_basis2") + ]: + assert bas.n_basis_funcs == n_basis_funcs + else: + assert trans_bas2.n_basis_funcs == n_basis_funcs + + +@pytest.mark.parametrize( + "set_input, expectation", + [ + (True, does_not_raise()), + ( + False, + pytest.raises( + RuntimeError, + match="Cannot initialize TransformerBasis: the provided basis has no defined input shape", + ), + ), + ], +) +@pytest.mark.parametrize( + "inp", [np.ones((10,)), np.ones((10, 1)), np.ones((10, 2)), np.ones((10, 2, 3))] +) +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_to_transformer_and_set_input( + basis_cls, inp, set_input, expectation, basis_class_specific_params +): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + if set_input: + bas.set_input_shape(*([inp] * bas._n_input_dimensionality)) + with expectation: + bas.to_transformer() + + +@pytest.mark.parametrize( + "inp, expectation", + [ + (np.ones((10,)), pytest.raises(ValueError, match="X must be 2-")), + (np.ones((10, 1)), does_not_raise()), + (np.ones((10, 2)), does_not_raise()), + (np.ones((10, 2, 3)), pytest.raises(ValueError, match="X must be 2-")), + ], +) +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), 
+) +def test_transformer_fit(basis_cls, inp, basis_class_specific_params, expectation): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + transformer = bas.set_input_shape( + *([inp] * bas._n_input_dimensionality) + ).to_transformer() + X = np.concatenate( + [inp.reshape(inp.shape[0], -1)] * bas._n_input_dimensionality, axis=1 + ) + transformer.fit(X) + if "Conv" in basis_cls.__name__: + assert transformer.kernel_ is not None + + # try and pass segmented time series + if isinstance(bas, (basis.AdditiveBasis, basis.MultiplicativeBasis)): + expectation = pytest.raises(ValueError, match="The input must be a 2-") + + with expectation: + transformer.fit(*([inp] * bas._n_input_dimensionality)) + + +@pytest.mark.parametrize( + "inp", + [ + np.ones((10, 1)), + np.ones((10, 2)), + ], +) +@pytest.mark.parametrize( + "delta_input, expectation", + [ + (0, does_not_raise()), + (1, pytest.raises(ValueError, match="Input mismatch: expected ")), + (-1, pytest.raises(ValueError, match="Input mismatch: expected ")), + ], +) +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_transformer_fit_input_shape_mismatch( + basis_cls, delta_input, inp, basis_class_specific_params, expectation +): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + transformer = bas.set_input_shape( + *([inp] * bas._n_input_dimensionality) + ).to_transformer() + X = np.random.randn(10, int(sum(bas._n_basis_input_) + delta_input)) + with expectation: + transformer.fit(X) + + +@pytest.mark.parametrize( + "inp", + [ + np.random.randn( + 10, + ), + np.random.randn(10, 1), + np.random.randn(10, 2), + np.random.randn(10, 2, 3), + ], +) +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_transformer_transform(basis_cls, inp, basis_class_specific_params): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, 
window_size=10 + ) + transformer = bas.set_input_shape( + *([inp] * bas._n_input_dimensionality) + ).to_transformer() + X = np.concatenate( + [inp.reshape(inp.shape[0], -1)] * bas._n_input_dimensionality, axis=1 + ) + transformer.fit(X) + + out = transformer.transform(X) + out2 = bas.compute_features(*([inp] * bas._n_input_dimensionality)) + + assert np.array_equal(out, out2, equal_nan=True) + + +@pytest.mark.parametrize( + "inp", + [ + np.random.randn( + 10, + ), + np.random.randn(10, 1), + np.random.randn(10, 2), + np.random.randn(10, 2, 3), + ], +) +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_transformer_fit_transform(basis_cls, inp, basis_class_specific_params): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + transformer = bas.set_input_shape( + *([inp] * bas._n_input_dimensionality) + ).to_transformer() + X = np.concatenate( + [inp.reshape(inp.shape[0], -1)] * bas._n_input_dimensionality, axis=1 + ) + + out = transformer.fit_transform(X) + out2 = bas.compute_features(*([inp] * bas._n_input_dimensionality)) + + assert np.array_equal(out, out2, equal_nan=True) + + +@pytest.mark.parametrize( + "inp", + [ + np.ones((10, 1)), + np.ones((10, 2)), + ], +) +@pytest.mark.parametrize( + "delta_input, expectation", + [ + (0, does_not_raise()), + (1, pytest.raises(ValueError, match="Input mismatch: expected ")), + (-1, pytest.raises(ValueError, match="Input mismatch: expected ")), + ], +) +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_transformer_fit_transform_input_shape_mismatch( + basis_cls, delta_input, inp, basis_class_specific_params, expectation +): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + transformer = bas.set_input_shape( + *([inp] * bas._n_input_dimensionality) + ).to_transformer() + X = np.random.randn(10, int(sum(bas._n_basis_input_) + delta_input)) + with 
expectation: + transformer.fit_transform(X) + + +@pytest.mark.parametrize( + "inp, expectation", + [ + (np.ones((10,)), pytest.raises(ValueError, match="X must be 2-")), + (np.ones((10, 1)), does_not_raise()), + (np.ones((10, 2)), does_not_raise()), + (np.ones((10, 2, 3)), pytest.raises(ValueError, match="X must be 2-")), + ], +) +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_transformer_fit_transform_input_struct( + basis_cls, inp, basis_class_specific_params, expectation +): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + transformer = bas.set_input_shape( + *([inp] * bas._n_input_dimensionality) + ).to_transformer() + X = np.concatenate( + [inp.reshape(inp.shape[0], -1)] * bas._n_input_dimensionality, axis=1 + ) + transformer.fit_transform(X) + + if "Conv" in basis_cls.__name__: + assert transformer.kernel_ is not None + + # try and pass a tuple of time series + if isinstance(bas, (basis.AdditiveBasis, basis.MultiplicativeBasis)): + expectation = pytest.raises(ValueError, match="The input must be a 2-") + + with expectation: + transformer.fit(*([inp] * bas._n_input_dimensionality)) From 59c0ef2ec97e30e8d4ab1546f4c3d98e01bfa65e Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 9 Dec 2024 15:06:01 -0500 Subject: [PATCH 12/37] added tests for fit and fit_transform on input types --- src/nemos/basis/_transformer_basis.py | 6 +++++- tests/test_transformer_basis.py | 12 +++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 165079a3..35bcf860 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -447,7 +447,7 @@ def _check_input(self, X: FeatureMatrix, y=None): If the input is not a 2-d array or if the number of columns does not match the expected number of inputs. 
""" ndim = getattr(X, "ndim", None) - if ndim is None or y is not None: + if ndim is None: raise ValueError("The input must be a 2-dimensional array.") elif ndim != 2: @@ -460,3 +460,7 @@ def _check_input(self, X: FeatureMatrix, y=None): f"Input mismatch: expected {sum(self.n_basis_input_)} inputs, but got {X.shape[1]} columns in X.\n" "To modify the required number of inputs, call `set_input_shape` before using `fit` or `fit_transform`." ) + + if y is not None and y.shape[0] != X.shape[0]: + raise ValueError("X and y must have the same number of samples. " + f"X has {X.shpae[0]} samples, while y has {y.shape[0]} samples.") \ No newline at end of file diff --git a/tests/test_transformer_basis.py b/tests/test_transformer_basis.py index aadd9174..3e0b33b1 100644 --- a/tests/test_transformer_basis.py +++ b/tests/test_transformer_basis.py @@ -381,7 +381,7 @@ def test_transformerbasis_sk_clone_kernel_noned(basis_cls, basis_class_specific_ trans_bas = basis.TransformerBasis(orig_bas) # kernel should be saved in the object after fit - trans_bas.fit(np.random.randn(100, 20)) + trans_bas.fit(np.random.randn(100, 1)) assert isinstance(trans_bas.kernel_, np.ndarray) # cloning should set kernel_ to None @@ -485,7 +485,8 @@ def test_transformer_fit(basis_cls, inp, basis_class_specific_params, expectatio # try and pass segmented time series if isinstance(bas, (basis.AdditiveBasis, basis.MultiplicativeBasis)): - expectation = pytest.raises(ValueError, match="The input must be a 2-") + if inp.ndim == 2: + expectation = pytest.raises(ValueError, match="Input mismatch: expected ") with expectation: transformer.fit(*([inp] * bas._n_input_dimensionality)) @@ -653,8 +654,9 @@ def test_transformer_fit_transform_input_struct( assert transformer.kernel_ is not None # try and pass a tuple of time series - if isinstance(bas, (basis.AdditiveBasis, basis.MultiplicativeBasis)): - expectation = pytest.raises(ValueError, match="The input must be a 2-") - + if isinstance(bas, 
(basis.AdditiveBasis, basis.MultiplicativeBasis)) and inp.ndim != 2: + expectation = pytest.raises(ValueError, match="X must be 2-") + elif isinstance(bas, (basis.AdditiveBasis, basis.MultiplicativeBasis)) and inp.ndim == 2: + expectation = pytest.raises(ValueError, match="Input mismatch: expected") with expectation: transformer.fit(*([inp] * bas._n_input_dimensionality)) From b5aac2a7cbd6ce2ada025044cbe4ad4db816ace3 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 9 Dec 2024 16:27:25 -0500 Subject: [PATCH 13/37] added pipeline testing with composite bases --- tests/test_transformer_basis.py | 57 ++++++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/tests/test_transformer_basis.py b/tests/test_transformer_basis.py index 3e0b33b1..91c81803 100644 --- a/tests/test_transformer_basis.py +++ b/tests/test_transformer_basis.py @@ -7,7 +7,8 @@ from sklearn.base import clone as sk_clone from nemos import basis - +import nemos as nmo +from sklearn.pipeline import Pipeline @pytest.mark.parametrize( "basis_cls", @@ -660,3 +661,57 @@ def test_transformer_fit_transform_input_struct( expectation = pytest.raises(ValueError, match="Input mismatch: expected") with expectation: transformer.fit(*([inp] * bas._n_input_dimensionality)) + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +@pytest.mark.parametrize( + "inp", + [ + np.random.randn(100,), + np.random.randn(100, 1), + np.random.randn(100, 2), + np.random.randn(100, 1, 2), + ], +) +def test_transformer_in_pipeline(basis_cls, inp, basis_class_specific_params): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + transformer = bas.set_input_shape( + *([inp] * bas._n_input_dimensionality) + ).to_transformer() + + # fit outside pipeline + X = bas.compute_features(*([inp] * bas._n_input_dimensionality)) + log_mu = X.dot(0.005 * np.ones(X.shape[1])) + y = np.full(X.shape[0], 0) + y[~np.isnan(log_mu)] = 
np.random.poisson(np.exp(log_mu[~np.isnan(log_mu)] - np.nanmean(log_mu))) + model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.001).fit(X, y) + + # pipeline + pipe = Pipeline( + [ + ("bas", transformer), + ("glm", nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.001)) + ] + ) + x = np.concatenate( + [inp.reshape(inp.shape[0], -1)] * bas._n_input_dimensionality, axis=1 + ) + pipe.fit(x, y) + np.testing.assert_allclose(pipe["glm"].coef_, model.coef_) + + # set basis & refit + if isinstance(bas, (basis.AdditiveBasis, basis.MultiplicativeBasis)): + pipe.set_params(bas__basis2__n_basis_funcs=4) + assert bas.basis2.n_basis_funcs == 5 # make sure that the change did not affect bas + X = bas.set_params(basis2__n_basis_funcs=4).compute_features(*([inp] * bas._n_input_dimensionality)) + else: + pipe.set_params(bas__n_basis_funcs=4) + assert bas.n_basis_funcs == 5 # make sure that the change did not affect bas + X = bas.set_params(n_basis_funcs=4).compute_features(*([inp] * bas._n_input_dimensionality)) + pipe.fit(x, y) + model.fit(X, y) + np.testing.assert_allclose(pipe["glm"].coef_, model.coef_) From dc7f554fcb5e32c8cddb1a09602ac9947dd75e2f Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 12 Dec 2024 16:54:39 -0500 Subject: [PATCH 14/37] removed kernel attr for eval bases. moved to conv bases only. 
create a private method _set_input_independent_sates --- src/nemos/basis/_basis.py | 33 +++++------------- src/nemos/basis/_basis_mixin.py | 49 ++++++++++++++++----------- src/nemos/basis/_transformer_basis.py | 8 +++-- tests/test_basis.py | 8 ++--- tests/test_transformer_basis.py | 38 +++++++++++++++------ 5 files changed, 76 insertions(+), 60 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 3dd7f3c7..01961634 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -152,9 +152,6 @@ def __init__( self._n_basis_input_ = getattr(self, "_n_basis_input_", None) self._input_shape_ = getattr(self, "_input_shape_", None) - # set by set_kernel - self.kernel_ = None - # initialize parent to None. This should not end in "_" because it is # a permanent property of a basis, defined at composite basis init self._parent = None @@ -240,7 +237,7 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: if self._n_basis_input_ is None: self.set_input_shape(*xi) self._check_input_shape_consistency(*xi) - self.set_kernel() + self._set_input_independent_states() return self._compute_features(*xi) @abc.abstractmethod @@ -273,19 +270,13 @@ def setup_basis(self, *xi: ArrayLike) -> FeatureMatrix: pass @abc.abstractmethod - def set_kernel(self): - """Set kernel for conv basis and return self or just return self for eval. - - For the basis API to work correctly, specifically, for the `_fit_basis` - method to work as intended, this method should set **all** state attributes - that do not require inspection of input time series. - - This method currently "just" sets the kernel because this is the only such state - but if in the future new states will be added, they must be funneled through this - method. + def _set_input_independent_states(self): + """ + Compute all the basis states that do not depend on the input. - Note that the name of this method can and should be refactored in case more such - states will be set in the future. 
+ An example of such state is the kernel_ for Conv baisis, which can be computed + without any input (it only depends on the basis type, the window size and the + number of basis elements). """ pass @@ -325,7 +316,8 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): is not set in this method, then ``compute_features`` (equivalent to ``fit_transform``) will break. Separating states related to the input (settable with this method) and states that are unrelated - from the input (settable with ``set_kernel``) is a deliberate design choice that improves modularity. + from the input (settable with ``set_kernel`` for Conv bases) is a deliberate design choice + that improves modularity. """ if isinstance(xi, tuple): @@ -436,13 +428,6 @@ def _check_transform_input( return xi - def _check_has_kernel(self) -> None: - """Check that the kernel is pre-computed.""" - if self.mode == "conv" and self.kernel_ is None: - raise ValueError( - "You must call `_set_kernel` before `_compute_features` when mode =`conv`." - ) - def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: """Evaluate the basis set on a grid of equi-spaced sample points. diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 423b9588..f625d790 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -70,13 +70,12 @@ def setup_basis(self, *xi: NDArray) -> Basis: : The basis with ready for evaluation. """ - self.set_kernel() self.set_input_shape(*xi) return self - def set_kernel(self) -> "EvalBasisMixin": + def _set_input_independent_states(self) -> "EvalBasisMixin": """ - Prepare or compute the convolutional kernel for the basis functions. + Compute all the basis states that do not depend on the input. For EvalBasisMixin, this method might not perform any operation but simply return the instance itself, as no kernel preparation is necessary. 
@@ -120,6 +119,7 @@ class ConvBasisMixin: def __init__( self, n_basis_funcs: int, window_size: int, conv_kwargs: Optional[dict] = None ): + self.kernel_ = None self.window_size = window_size self.conv_kwargs = {} if conv_kwargs is None else conv_kwargs self._n_basis_funcs = n_basis_funcs @@ -154,7 +154,7 @@ def _compute_features(self, *xi: NDArray): """ if self.kernel_ is None: raise ValueError( - "You must call `_set_kernel` before `_compute_features`! " + "You must call `setup_basis` before `_compute_features`! " "Convolution kernel is not set." ) # before calling the convolve, check that the input matches @@ -188,6 +188,14 @@ def setup_basis(self, *xi: NDArray) -> Basis: self.set_input_shape(*xi) return self + def _set_input_independent_states(self): + """ + Compute all the basis states that do not depend on the input. + + For Conv mixin the only attribute is the kernel. + """ + return self.set_kernel() + def set_kernel(self) -> "ConvBasisMixin": """ Prepare or compute the convolutional kernel for the basis functions. @@ -296,6 +304,13 @@ def _check_convolution_kwargs(conv_kwargs: dict): f"Allowed convolution keyword arguments are: {convolve_configs}." ) + def _check_has_kernel(self) -> None: + """Check that the kernel is pre-computed.""" + if self.kernel_ is None: + raise ValueError( + "You must call `_set_kernel` before `_compute_features` for Conv basis." + ) + class BasisTransformerMixin: """Mixin class for constructing a transformer.""" @@ -393,29 +408,25 @@ def setup_basis(self, *xi: NDArray) -> Basis: : The basis with ready for evaluation. """ - self.set_kernel() - self.basis1.set_input_shape(*xi[: self._basis1._n_input_dimensionality]) - self.basis2.set_input_shape(*xi[self._basis1._n_input_dimensionality :]) - return self + # setup both input independent + self._set_input_independent_states() - def set_kernel(self) -> Basis: - """Call set_kernel on the basis elements. 
+ # and input dependent states + self.set_input_shape(*xi) - If any of the basis elements is in "conv" mode, it will prepare its kernels for the convolution. - Addi - Also grabs input shapes if provided, similar to what sklean transformer `fit` method does + return self - Parameters - ---------- + def _set_input_independent_states(self): + """ + Compute the input dependent states for traversing the composite basis. Returns ------- : - The basis with the kernels set up. + The basis with the states stored as attributes of each component. """ - self._basis1.set_kernel() - self._basis2.set_kernel() - return self + self.basis1._set_input_independent_states() + self.basis2._set_input_independent_states() def _check_input_shape_consistency(self, *xi: NDArray): """Check the input shape consistency for all basis elements.""" diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 35bcf860..e6b237f5 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -151,7 +151,7 @@ def fit(self, X: FeatureMatrix, y=None): >>> transformer_fitted = transformer.fit(X) """ self._check_input(X, y) - self._basis.set_kernel() + self._basis.setup_basis(*self._unpack_inputs(X)) return self def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: @@ -462,5 +462,7 @@ def _check_input(self, X: FeatureMatrix, y=None): ) if y is not None and y.shape[0] != X.shape[0]: - raise ValueError("X and y must have the same number of samples. " - f"X has {X.shpae[0]} samples, while y has {y.shape[0]} samples.") \ No newline at end of file + raise ValueError( + "X and y must have the same number of samples. " + f"X has {X.shpae[0]} samples, while y has {y.shape[0]} samples." 
+ ) diff --git a/tests/test_basis.py b/tests/test_basis.py index 55f4335f..4315237d 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -2490,7 +2490,7 @@ def test_fit_kernel( n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) bas = basis_a_obj + basis_b_obj - bas.set_kernel() + bas.setup_basis(*([np.ones(10)] * bas._n_input_dimensionality)) def check_kernel(basis_obj): has_kern = [] @@ -2524,7 +2524,7 @@ def test_transform_fails( else: context = pytest.raises( ValueError, - match="You must call `_set_kernel` before `_compute_features`", + match="You must call `setup_basis` before `_compute_features`", ) with context: x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality @@ -3418,7 +3418,7 @@ def test_fit_kernel( n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) bas = basis_a_obj * basis_b_obj - bas.set_kernel() + bas._set_input_independent_states() def check_kernel(basis_obj): has_kern = [] @@ -3452,7 +3452,7 @@ def test_transform_fails( else: context = pytest.raises( ValueError, - match="You must call `_set_kernel` before `_compute_features`", + match="You must call `setup_basis` before `_compute_features`", ) with context: x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality diff --git a/tests/test_transformer_basis.py b/tests/test_transformer_basis.py index 91c81803..50b58f13 100644 --- a/tests/test_transformer_basis.py +++ b/tests/test_transformer_basis.py @@ -5,10 +5,11 @@ import pytest from conftest import CombinedBasis, list_all_basis_classes from sklearn.base import clone as sk_clone +from sklearn.pipeline import Pipeline -from nemos import basis import nemos as nmo -from sklearn.pipeline import Pipeline +from nemos import basis + @pytest.mark.parametrize( "basis_cls", @@ -655,13 +656,20 @@ def test_transformer_fit_transform_input_struct( assert transformer.kernel_ is not None # try and pass a tuple of time series - if isinstance(bas, (basis.AdditiveBasis, basis.MultiplicativeBasis)) and inp.ndim != 2: + if 
( + isinstance(bas, (basis.AdditiveBasis, basis.MultiplicativeBasis)) + and inp.ndim != 2 + ): expectation = pytest.raises(ValueError, match="X must be 2-") - elif isinstance(bas, (basis.AdditiveBasis, basis.MultiplicativeBasis)) and inp.ndim == 2: + elif ( + isinstance(bas, (basis.AdditiveBasis, basis.MultiplicativeBasis)) + and inp.ndim == 2 + ): expectation = pytest.raises(ValueError, match="Input mismatch: expected") with expectation: transformer.fit(*([inp] * bas._n_input_dimensionality)) + @pytest.mark.parametrize( "basis_cls", list_all_basis_classes(), @@ -669,7 +677,9 @@ def test_transformer_fit_transform_input_struct( @pytest.mark.parametrize( "inp", [ - np.random.randn(100,), + np.random.randn( + 100, + ), np.random.randn(100, 1), np.random.randn(100, 2), np.random.randn(100, 1, 2), @@ -687,14 +697,16 @@ def test_transformer_in_pipeline(basis_cls, inp, basis_class_specific_params): X = bas.compute_features(*([inp] * bas._n_input_dimensionality)) log_mu = X.dot(0.005 * np.ones(X.shape[1])) y = np.full(X.shape[0], 0) - y[~np.isnan(log_mu)] = np.random.poisson(np.exp(log_mu[~np.isnan(log_mu)] - np.nanmean(log_mu))) + y[~np.isnan(log_mu)] = np.random.poisson( + np.exp(log_mu[~np.isnan(log_mu)] - np.nanmean(log_mu)) + ) model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.001).fit(X, y) # pipeline pipe = Pipeline( [ ("bas", transformer), - ("glm", nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.001)) + ("glm", nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.001)), ] ) x = np.concatenate( @@ -706,12 +718,18 @@ def test_transformer_in_pipeline(basis_cls, inp, basis_class_specific_params): # set basis & refit if isinstance(bas, (basis.AdditiveBasis, basis.MultiplicativeBasis)): pipe.set_params(bas__basis2__n_basis_funcs=4) - assert bas.basis2.n_basis_funcs == 5 # make sure that the change did not affect bas - X = bas.set_params(basis2__n_basis_funcs=4).compute_features(*([inp] * bas._n_input_dimensionality)) + assert ( + 
bas.basis2.n_basis_funcs == 5 + ) # make sure that the change did not affect bas + X = bas.set_params(basis2__n_basis_funcs=4).compute_features( + *([inp] * bas._n_input_dimensionality) + ) else: pipe.set_params(bas__n_basis_funcs=4) assert bas.n_basis_funcs == 5 # make sure that the change did not affect bas - X = bas.set_params(n_basis_funcs=4).compute_features(*([inp] * bas._n_input_dimensionality)) + X = bas.set_params(n_basis_funcs=4).compute_features( + *([inp] * bas._n_input_dimensionality) + ) pipe.fit(x, y) model.fit(X, y) np.testing.assert_allclose(pipe["glm"].coef_, model.coef_) From aeeb89c7693a64f3294390ae5142b3010afc5ed3 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 12 Dec 2024 18:31:42 -0500 Subject: [PATCH 15/37] runtime wrapping --- src/nemos/basis/_transformer_basis.py | 41 +++++++++++++++++++++++---- tests/test_transformer_basis.py | 18 +++++++++++- 2 files changed, 52 insertions(+), 7 deletions(-) diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index e6b237f5..dd8ea94e 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +from functools import wraps from typing import TYPE_CHECKING, List import numpy as np @@ -11,6 +12,17 @@ from ._basis import Basis +def transformer_chaining(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + # Call the wrapped function and capture its return value + result = func(*args, **kwargs) + + # If the method returns the inner `self`, replace it with the outer `self` (no deepcopy here). + return self if result is self._basis else result + + return wrapper + class TransformerBasis: """Basis as ``scikit-learn`` transformers. 
@@ -61,16 +73,19 @@ class TransformerBasis: Cross-validated number of basis: {'compute_features__n_basis_funcs': 10} """ + _chainable_methods = ("set_kernel", "set_input_shape", "_set_input_independent_states", "setup_basis") + def __init__(self, basis: Basis): - self._check_initialized(basis) self._basis = copy.deepcopy(basis) + @staticmethod def _check_initialized(basis): if basis._n_basis_input_ is None: raise RuntimeError( "Cannot initialize TransformerBasis: the provided basis has no defined input shape. " - "Please call `set_input_shape` on the basis before initializing the transformer." + "Please call `set_input_shape` on the basis before calling `fit`, `transform`, or " + "`fit_transform`." ) @property @@ -145,11 +160,12 @@ def fit(self, X: FeatureMatrix, y=None): >>> # Example input >>> X = np.random.normal(size=(100, 2)) - >>> # Define and fit tranformation basis + >>> # Define, setup and fit transformer basis >>> basis = MSplineEval(10) - >>> transformer = TransformerBasis(basis) + >>> transformer = TransformerBasis(basis).set_input_shape(2) >>> transformer_fitted = transformer.fit(X) """ + self._check_initialized(self._basis) self._check_input(X, y) self._basis.setup_basis(*self._unpack_inputs(X)) return self @@ -191,6 +207,7 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: >>> # Transform basis >>> feature_transformed = transformer.transform(X[:, 0:1]) """ + self._check_initialized(self._basis) # transpose does not work with pynapple # can't use func(*X.T) to unwrap return self._basis._compute_features(*self._unpack_inputs(X)) @@ -221,17 +238,22 @@ def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: >>> from nemos.basis import MSplineEval, TransformerBasis >>> # Example input - >>> X = np.random.normal(size=(100, 1)) + >>> n_inputs = 2 + >>> X = np.random.normal(size=(100, 2)) >>> # Define tranformation basis >>> basis = MSplineEval(10) + >>> # Prepare basis to process 2 inputs + >>> # This step must be done before + 
>>> basis.set_input_shape(n_inputs) + >>> transformer = TransformerBasis(basis) >>> # Fit and transform basis >>> feature_transformed = transformer.fit_transform(X) """ self.fit(X, y=y) - return self._basis.compute_features(*self._unpack_inputs(X)) + return self.transform(X) def __getstate__(self): """ @@ -268,6 +290,13 @@ def __getattr__(self, name: str): >>> trans_bas.n_basis_funcs 5 """ + # set chainable methods decorating the basis method + # this must be done lazily (runtime) when the attribute is requested + # otherwise it will create an infinite loop when pickling + if name in self._chainable_methods: + method = getattr(self._basis, name, None) + if method is not None: + return transformer_chaining(method).__get__(self) return getattr(self._basis, name) def __setattr__(self, name: str, value) -> None: diff --git a/tests/test_transformer_basis.py b/tests/test_transformer_basis.py index 50b58f13..e8245051 100644 --- a/tests/test_transformer_basis.py +++ b/tests/test_transformer_basis.py @@ -68,6 +68,21 @@ def test_to_transformer_and_constructor_are_equivalent( trans_bas_a._basis.__dict__.pop("_decay_rates", 1) == trans_bas_b._basis.__dict__.pop("_decay_rates", 1) ) + + # extract the wrapped func for these methods + wrapped_methods_a = {} + for method in trans_bas_a._chainable_methods: + out = trans_bas_a._basis.__dict__.pop(method, False) + val = out if out is False else out.__func__.__qualname__ + wrapped_methods_a.update({method: val}) + + wrapped_methods_b = {} + for method in trans_bas_b._chainable_methods: + out = trans_bas_b._basis.__dict__.pop(method, False) + val = out if out is False else out.__func__.__qualname__ + wrapped_methods_b.update({method: val}) + + assert wrapped_methods_a == wrapped_methods_b assert trans_bas_a._basis.__dict__ == trans_bas_b._basis.__dict__ @@ -454,8 +469,9 @@ def test_to_transformer_and_set_input( ) if set_input: bas.set_input_shape(*([inp] * bas._n_input_dimensionality)) + trans = bas.to_transformer() with expectation: - 
bas.to_transformer() + trans.fit(inp) @pytest.mark.parametrize( From 387bcad338a5c82237173159d80a0d8832426778 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 12 Dec 2024 19:06:10 -0500 Subject: [PATCH 16/37] fixed chaining, doctests --- src/nemos/basis/_basis_mixin.py | 2 +- src/nemos/basis/_transformer_basis.py | 56 +++++++++++++++++++-------- tests/test_transformer_basis.py | 7 +++- 3 files changed, 45 insertions(+), 20 deletions(-) diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index f625d790..d1f9f514 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -328,7 +328,7 @@ def to_transformer(self) -> TransformerBasis: >>> from sklearn.model_selection import GridSearchCV >>> # load some data >>> X, y = np.random.normal(size=(30, 1)), np.random.poisson(size=30) - >>> basis = nmo.basis.RaisedCosineLinearEval(10).to_transformer() + >>> basis = nmo.basis.RaisedCosineLinearEval(10).set_input_shape(1).to_transformer() >>> glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=1.) >>> pipeline = Pipeline([("basis", basis), ("glm", glm)]) >>> param_grid = dict( diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index dd8ea94e..f0fdc9df 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -23,6 +23,7 @@ def wrapper(self, *args, **kwargs): return wrapper + class TransformerBasis: """Basis as ``scikit-learn`` transformers. 
@@ -73,11 +74,16 @@ class TransformerBasis: Cross-validated number of basis: {'compute_features__n_basis_funcs': 10} """ - _chainable_methods = ("set_kernel", "set_input_shape", "_set_input_independent_states", "setup_basis") + _chainable_methods = ( + "set_kernel", + "set_input_shape", + "_set_input_independent_states", + "setup_basis", + ) def __init__(self, basis: Basis): self._basis = copy.deepcopy(basis) - + self._wrapped_methods = {} # Cache for wrapped methods @staticmethod def _check_initialized(basis): @@ -199,13 +205,13 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: >>> # Before calling `fit` the convolution kernel is not set >>> transformer.kernel_ - >>> transformer_fitted = transformer.fit(X) + >>> transformer_fitted = transformer.set_input_shape(2).fit(X) >>> # Now the convolution kernel is initialized and has shape (window_size, n_basis_funcs) >>> transformer_fitted.kernel_.shape (200, 10) >>> # Transform basis - >>> feature_transformed = transformer.transform(X[:, 0:1]) + >>> feature_transformed = transformer.transform(X) """ self._check_initialized(self._basis) # transpose does not work with pynapple @@ -245,7 +251,7 @@ def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: >>> basis = MSplineEval(10) >>> # Prepare basis to process 2 inputs >>> # This step must be done before - >>> basis.set_input_shape(n_inputs) + >>> basis = basis.set_input_shape(n_inputs) >>> transformer = TransformerBasis(basis) @@ -262,6 +268,10 @@ def __getstate__(self): See https://docs.python.org/3/library/pickle.html#object.__getstate__ and https://docs.python.org/3/library/pickle.html#pickle-state """ + # this is the only state needed at initalization + # returning the cached wrapped methods would create + # a circular binding of the state to self (i.e. infinite recursion when + # unpickling). 
return {"_basis": self._basis} def __setstate__(self, state): @@ -275,11 +285,16 @@ def __setstate__(self, state): and https://docs.python.org/3/library/pickle.html#pickle-state """ self._basis = state["_basis"] + self._wrapped_methods = {} # Reinitialize the cache def __getattr__(self, name: str): """ Enable easy access to attributes of the underlying Basis object. + This method chaces all chainable methods (methods returning self) in a dicitonary. + These mehtods are created the first time they are accessed by decorating the `self._basis.name` + and cached for future use. + Examples -------- >>> from nemos import basis @@ -290,14 +305,21 @@ def __getattr__(self, name: str): >>> trans_bas.n_basis_funcs 5 """ - # set chainable methods decorating the basis method - # this must be done lazily (runtime) when the attribute is requested - # otherwise it will create an infinite loop when pickling + # Check if the method has already been wrapped + if name in self._wrapped_methods: + return self._wrapped_methods[name] + + # Get the original attribute from the basis + attr = getattr(self._basis, name) + + # If the attribute is a callable method, wrap it dynamically if name in self._chainable_methods: - method = getattr(self._basis, name, None) - if method is not None: - return transformer_chaining(method).__get__(self) - return getattr(self._basis, name) + wrapped = transformer_chaining(attr).__get__(self) + self._wrapped_methods[name] = wrapped # Cache the wrapped method + return wrapped + + # For non-callable attributes, return them directly + return attr def __setattr__(self, name: str, value) -> None: r""" @@ -324,13 +346,13 @@ def __setattr__(self, name: str, value) -> None: >>> trans_bas.n_basis_funcs = 20 >>> # not allowed >>> try: - ... trans_bas.random_attribute_name = "some value" + ... trans_bas.rand_attr = "some value" ... except ValueError as e: ... 
print(repr(e)) - ValueError('Only setting _basis or existing attributes of _basis is allowed.') + ValueError('Only setting _basis or existing attributes of _basis is allowed. Attempt to set `rand_attr`.') """ - # allow self._basis = basis - if name == "_basis" or name == "basis": + # allow self._basis = basis and other attrs of self to be retrievable + if name == "_basis" or name == "basis" or name == "_wrapped_methods": super().__setattr__(name, value) # allow changing existing attributes of self._basis elif hasattr(self._basis, name): @@ -338,7 +360,7 @@ def __setattr__(self, name: str, value) -> None: # don't allow setting any other attribute else: raise ValueError( - "Only setting _basis or existing attributes of _basis is allowed." + f"Only setting _basis or existing attributes of _basis is allowed. Attempt to set `{name}`." ) def __sklearn_clone__(self) -> TransformerBasis: diff --git a/tests/test_transformer_basis.py b/tests/test_transformer_basis.py index e8245051..c301adc5 100644 --- a/tests/test_transformer_basis.py +++ b/tests/test_transformer_basis.py @@ -61,7 +61,7 @@ def test_to_transformer_and_constructor_are_equivalent( assert ( list(trans_bas_a.__dict__.keys()) == list(trans_bas_b.__dict__.keys()) - == ["_basis"] + == ["_basis", "_wrapped_methods"] ) # and those bases are the same assert np.all( @@ -471,7 +471,10 @@ def test_to_transformer_and_set_input( bas.set_input_shape(*([inp] * bas._n_input_dimensionality)) trans = bas.to_transformer() with expectation: - trans.fit(inp) + X = np.concatenate( + [inp.reshape(inp.shape[0], -1)] * bas._n_input_dimensionality, axis=1 + ) + trans.fit(X) @pytest.mark.parametrize( From f53dc34c801816e09058da94a6ce70ace62642fa Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 12 Dec 2024 19:32:32 -0500 Subject: [PATCH 17/37] fixed chaining, doctests, and tests --- src/nemos/basis/_basis.py | 11 +++++++++++ tests/test_basis.py | 2 +- tests/test_pipeline.py | 18 +++++++++++++----- 3 files changed, 25 
insertions(+), 6 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 01961634..742f0a7b 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -838,6 +838,17 @@ class is accidentally removed. ) return [self] + def __sklearn_clone__(self): + """Deep copy the basis. + + This keeps the input shapes. Reinitializing the class, as in the regular + sklearn clone would drop them, making the cross-validation unusable. + """ + bas = copy.deepcopy(self) + if hasattr(bas, "kernel_"): + bas.kernel_ = None + return bas + class AdditiveBasis(CompositeBasisMixin, Basis): """ diff --git a/tests/test_basis.py b/tests/test_basis.py index 4315237d..5858bb85 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -1202,7 +1202,7 @@ def test_transform_fails(self, cls): n_basis_funcs=5, window_size=5, **extra_decay_rates(cls["conv"], 5) ) with pytest.raises( - ValueError, match="You must call `_set_kernel` before `_compute_features`" + ValueError, match="You must call `setup_basis` before `_compute_features`" ): bas._compute_features(np.linspace(0, 1, 10)) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 5e4ce13d..cf75ad3e 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -21,7 +21,7 @@ ) def test_sklearn_transformer_pipeline(bas, poissonGLM_model_instantiation): X, y, model, _, _ = poissonGLM_model_instantiation - bas = TransformerBasis(bas) + bas = TransformerBasis(bas).set_input_shape(*([1] * bas._n_input_dimensionality)) pipe = pipeline.Pipeline([("eval", bas), ("fit", model)]) pipe.fit(X[:, : bas._basis._n_input_dimensionality] ** 2, y) @@ -39,7 +39,7 @@ def test_sklearn_transformer_pipeline(bas, poissonGLM_model_instantiation): ) def test_sklearn_transformer_pipeline_cv(bas, poissonGLM_model_instantiation): X, y, model, _, _ = poissonGLM_model_instantiation - bas = TransformerBasis(bas) + bas = TransformerBasis(bas).set_input_shape(*([1] * bas._n_input_dimensionality)) pipe = 
pipeline.Pipeline([("basis", bas), ("fit", model)]) param_grid = dict(basis__n_basis_funcs=(4, 5, 10)) gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score="raise") @@ -60,7 +60,7 @@ def test_sklearn_transformer_pipeline_cv_multiprocess( bas, poissonGLM_model_instantiation ): X, y, model, _, _ = poissonGLM_model_instantiation - bas = TransformerBasis(bas) + bas = TransformerBasis(bas).set_input_shape(*([1] * bas._n_input_dimensionality)) pipe = pipeline.Pipeline([("basis", bas), ("fit", model)]) param_grid = dict(basis__n_basis_funcs=(4, 5, 10)) gridsearch = GridSearchCV( @@ -86,8 +86,15 @@ def test_sklearn_transformer_pipeline_cv_directly_over_basis( ): X, y, model, _, _ = poissonGLM_model_instantiation bas = TransformerBasis(bas_cls(5)) + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) pipe = pipeline.Pipeline([("transformerbasis", bas), ("fit", model)]) - param_grid = dict(transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20))) + param_grid = dict( + transformerbasis___basis=( + bas_cls(5).set_input_shape(*([1] * bas._n_input_dimensionality)), + bas_cls(10).set_input_shape(*([1] * bas._n_input_dimensionality)), + bas_cls(20).set_input_shape(*([1] * bas._n_input_dimensionality)), + ) + ) gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score="raise") gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y) @@ -107,6 +114,7 @@ def test_sklearn_transformer_pipeline_cv_illegal_combination( ): X, y, model, _, _ = poissonGLM_model_instantiation bas = TransformerBasis(bas_cls(5)) + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) pipe = pipeline.Pipeline([("transformerbasis", bas), ("fit", model)]) param_grid = dict( transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20)), @@ -165,7 +173,7 @@ def test_sklearn_transformer_pipeline_pynapple( ep = nap.IntervalSet(start=[0, 20.5], end=[20, X.shape[0]]) X_nap = nap.TsdFrame(t=np.arange(X.shape[0]), d=X, time_support=ep) y_nap = 
nap.Tsd(t=np.arange(X.shape[0]), d=y, time_support=ep) - bas = TransformerBasis(bas) + bas = TransformerBasis(bas).set_input_shape(*([1] * bas._n_input_dimensionality)) # fit a pipeline & predict from pynapple pipe = pipeline.Pipeline([("eval", bas), ("fit", model)]) pipe.fit(X_nap[:, : bas._basis._n_input_dimensionality] ** 2, y_nap) From c0d774920ae71ad11717d0cb96f20a834478a342 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 12 Dec 2024 21:42:40 -0500 Subject: [PATCH 18/37] add clone method --- src/nemos/basis/_basis.py | 18 ++++++++--- tests/test_basis.py | 68 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 5 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 742f0a7b..24976652 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -839,15 +839,23 @@ class is accidentally removed. return [self] def __sklearn_clone__(self): - """Deep copy the basis. + """Clone the basis while preserving input shapes related attributes. This keeps the input shapes. Reinitializing the class, as in the regular sklearn clone would drop them, making the cross-validation unusable. 
""" - bas = copy.deepcopy(self) - if hasattr(bas, "kernel_"): - bas.kernel_ = None - return bas + # clone recursively + if hasattr(self, "_basis1") and hasattr(self, "_basis2"): + basis1 = self._basis1.__sklearn_clone__() + basis2 = self._basis2.__sklearn_clone__() + klass = self.__class__(basis1, basis2) + + else: + klass = self.__class__(**self.get_params()) + + for attr_name in ["_n_basis_input_", "_input_shape_"]: + setattr(klass, attr_name, getattr(self, attr_name)) + return klass class AdditiveBasis(CompositeBasisMixin, Basis): diff --git a/tests/test_basis.py b/tests/test_basis.py index 5858bb85..fd5e57ad 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -300,6 +300,31 @@ def test_call_vmin_vmax(self, samples, vmin, vmax, expectation, cls): with expectation: bas._evaluate(samples) + + @pytest.mark.parametrize("n_basis", [5, 6]) + @pytest.mark.parametrize("vmin, vmax", [(0, 1), (-1, 1)]) + @pytest.mark.parametrize("inp_num", [1, 2]) + def test_sklearn_clone_eval(self, cls, n_basis, vmin, vmax,inp_num): + bas = cls["eval"](n_basis, bounds=(vmin, vmax), **extra_decay_rates(cls["eval"], n_basis)) + bas.set_input_shape(inp_num) + bas2 = bas.__sklearn_clone__() + assert id(bas) != id(bas2) + assert np.all(bas.__dict__.pop("decay_rates", True) == bas2.__dict__.pop("decay_rates", True)) + assert bas.__dict__ == bas2.__dict__ + + + @pytest.mark.parametrize("n_basis", [5, 6]) + @pytest.mark.parametrize("ws", [10, 20]) + @pytest.mark.parametrize("inp_num", [1, 2]) + def test_sklearn_clone_conv(self, cls, n_basis, ws,inp_num): + bas = cls["conv"](n_basis, window_size=ws, **extra_decay_rates(cls["eval"], n_basis)) + bas.set_input_shape(inp_num) + bas2 = bas.__sklearn_clone__() + assert id(bas) != id(bas2) + assert np.all(bas.__dict__.pop("decay_rates", True) == bas2.__dict__.pop("decay_rates", True)) + assert bas.__dict__ == bas2.__dict__ + + @pytest.mark.parametrize( "attribute, value", [ @@ -1956,6 +1981,49 @@ def test_compute_features_input(self, 
eval_input): basis_obj = basis.MSplineEval(5) + basis.MSplineEval(5) basis_obj.compute_features(*eval_input) + @pytest.mark.parametrize("n_basis_a", [6]) + @pytest.mark.parametrize("n_basis_b", [5]) + @pytest.mark.parametrize("vmin, vmax", [(-1, 1)]) + @pytest.mark.parametrize("basis_a", list_all_basis_classes()) + @pytest.mark.parametrize("basis_b", list_all_basis_classes()) + @pytest.mark.parametrize("inp_num", [1, 2]) + def test_sklearn_clone(self, basis_a, basis_b, n_basis_a, n_basis_b, vmin, vmax, inp_num, basis_class_specific_params): + """Recursively check cloning.""" + basis_a_obj = self.instantiate_basis( + n_basis_a, basis_a, basis_class_specific_params, window_size=10 + ) + basis_a_obj = basis_a_obj.set_input_shape(*([inp_num] * basis_a_obj._n_input_dimensionality)) + basis_b_obj = self.instantiate_basis( + n_basis_b, basis_b, basis_class_specific_params, window_size=15 + ) + basis_b_obj = basis_b_obj.set_input_shape(*([inp_num]*basis_b_obj._n_input_dimensionality)) + add = basis_a_obj + basis_b_obj + + def filter_attributes(obj, exclude_keys): + return {key: val for key, val in obj.__dict__.items() if key not in exclude_keys} + + def compare(b1, b2): + assert id(b1) != id(b2) + assert b1.__class__.__name__ == b2.__class__.__name__ + if hasattr(b1, "basis1"): + compare(b1.basis1, b2.basis1) + compare(b1.basis2, b2.basis2) + # add all params that are not parent or _basis1,_basis2 + d1 = filter_attributes(b1, exclude_keys=["_basis1", "_basis2", "_parent"]) + d2 = filter_attributes(b2, exclude_keys=["_basis1", "_basis2", "_parent"]) + assert d1 == d2 + else: + decay_rates_b1 = b1.__dict__.get("_decay_rates", -1) + decay_rates_b2 = b2.__dict__.get("_decay_rates", -1) + assert np.array_equal(decay_rates_b1, decay_rates_b2) + d1 = filter_attributes(b1, exclude_keys=["_decay_rates", "_parent"]) + d2 = filter_attributes(b2, exclude_keys=["_decay_rates", "_parent"]) + assert d1 == d2 + + add2 = add.__sklearn_clone__() + compare(add, add2) + + 
@pytest.mark.parametrize("n_basis_a", [5, 6]) @pytest.mark.parametrize("n_basis_b", [5, 6]) @pytest.mark.parametrize("sample_size", [10, 1000]) From f646ad08863cdfd6149c8005a3181bbc01845b26 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 12 Dec 2024 21:53:12 -0500 Subject: [PATCH 19/37] add clone method --- src/nemos/basis/_basis.py | 13 ++++++++----- src/nemos/basis/_transformer_basis.py | 4 ++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 24976652..c5615e37 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -838,11 +838,14 @@ class is accidentally removed. ) return [self] - def __sklearn_clone__(self): - """Clone the basis while preserving input shapes related attributes. - - This keeps the input shapes. Reinitializing the class, as in the regular - sklearn clone would drop them, making the cross-validation unusable. + def __sklearn_clone__(self) -> Basis: + """Clone the basis while preserving attributes related to input shapes. + + This method ensures that input shape attributes (e.g., `_n_basis_input_`, + `_input_shape_`) are preserved during cloning. Reinitializing the class + as in the regular sklearn clone would drop these attributes, rendering + cross-validation unusable. + The method also handles recursive cloning for composite basis structures. """ # clone recursively if hasattr(self, "_basis1") and hasattr(self, "_basis2"): diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index f0fdc9df..dc33b29a 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -352,7 +352,7 @@ def __setattr__(self, name: str, value) -> None: ValueError('Only setting _basis or existing attributes of _basis is allowed. 
Attempt to set `rand_attr`.') """ # allow self._basis = basis and other attrs of self to be retrievable - if name == "_basis" or name == "basis" or name == "_wrapped_methods": + if name in ["_basis", "basis", "_wrapped_methods"]: super().__setattr__(name, value) # allow changing existing attributes of self._basis elif hasattr(self._basis, name): @@ -372,7 +372,7 @@ def __sklearn_clone__(self) -> TransformerBasis: For more info: https://scikit-learn.org/stable/developers/develop.html#cloning """ - cloned_obj = TransformerBasis(copy.deepcopy(self._basis)) + cloned_obj = TransformerBasis(self._basis.__sklearn_clone__()) cloned_obj._basis.kernel_ = None return cloned_obj From 6e4d42369b45601a141dd2d9dcc7ef2cc8b44c74 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 12 Dec 2024 21:56:03 -0500 Subject: [PATCH 20/37] linted --- tests/test_basis.py | 58 ++++++++++++++++++++++++++++++++------------- 1 file changed, 42 insertions(+), 16 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index fd5e57ad..12f967e0 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -300,31 +300,38 @@ def test_call_vmin_vmax(self, samples, vmin, vmax, expectation, cls): with expectation: bas._evaluate(samples) - @pytest.mark.parametrize("n_basis", [5, 6]) @pytest.mark.parametrize("vmin, vmax", [(0, 1), (-1, 1)]) @pytest.mark.parametrize("inp_num", [1, 2]) - def test_sklearn_clone_eval(self, cls, n_basis, vmin, vmax,inp_num): - bas = cls["eval"](n_basis, bounds=(vmin, vmax), **extra_decay_rates(cls["eval"], n_basis)) + def test_sklearn_clone_eval(self, cls, n_basis, vmin, vmax, inp_num): + bas = cls["eval"]( + n_basis, bounds=(vmin, vmax), **extra_decay_rates(cls["eval"], n_basis) + ) bas.set_input_shape(inp_num) bas2 = bas.__sklearn_clone__() assert id(bas) != id(bas2) - assert np.all(bas.__dict__.pop("decay_rates", True) == bas2.__dict__.pop("decay_rates", True)) + assert np.all( + bas.__dict__.pop("decay_rates", True) + == bas2.__dict__.pop("decay_rates", 
True) + ) assert bas.__dict__ == bas2.__dict__ - @pytest.mark.parametrize("n_basis", [5, 6]) @pytest.mark.parametrize("ws", [10, 20]) @pytest.mark.parametrize("inp_num", [1, 2]) - def test_sklearn_clone_conv(self, cls, n_basis, ws,inp_num): - bas = cls["conv"](n_basis, window_size=ws, **extra_decay_rates(cls["eval"], n_basis)) + def test_sklearn_clone_conv(self, cls, n_basis, ws, inp_num): + bas = cls["conv"]( + n_basis, window_size=ws, **extra_decay_rates(cls["eval"], n_basis) + ) bas.set_input_shape(inp_num) bas2 = bas.__sklearn_clone__() assert id(bas) != id(bas2) - assert np.all(bas.__dict__.pop("decay_rates", True) == bas2.__dict__.pop("decay_rates", True)) + assert np.all( + bas.__dict__.pop("decay_rates", True) + == bas2.__dict__.pop("decay_rates", True) + ) assert bas.__dict__ == bas2.__dict__ - @pytest.mark.parametrize( "attribute, value", [ @@ -1987,20 +1994,36 @@ def test_compute_features_input(self, eval_input): @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("inp_num", [1, 2]) - def test_sklearn_clone(self, basis_a, basis_b, n_basis_a, n_basis_b, vmin, vmax, inp_num, basis_class_specific_params): + def test_sklearn_clone( + self, + basis_a, + basis_b, + n_basis_a, + n_basis_b, + vmin, + vmax, + inp_num, + basis_class_specific_params, + ): """Recursively check cloning.""" basis_a_obj = self.instantiate_basis( n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) - basis_a_obj = basis_a_obj.set_input_shape(*([inp_num] * basis_a_obj._n_input_dimensionality)) + basis_a_obj = basis_a_obj.set_input_shape( + *([inp_num] * basis_a_obj._n_input_dimensionality) + ) basis_b_obj = self.instantiate_basis( n_basis_b, basis_b, basis_class_specific_params, window_size=15 ) - basis_b_obj = basis_b_obj.set_input_shape(*([inp_num]*basis_b_obj._n_input_dimensionality)) + basis_b_obj = basis_b_obj.set_input_shape( + *([inp_num] * 
basis_b_obj._n_input_dimensionality) + ) add = basis_a_obj + basis_b_obj def filter_attributes(obj, exclude_keys): - return {key: val for key, val in obj.__dict__.items() if key not in exclude_keys} + return { + key: val for key, val in obj.__dict__.items() if key not in exclude_keys + } def compare(b1, b2): assert id(b1) != id(b2) @@ -2009,8 +2032,12 @@ def compare(b1, b2): compare(b1.basis1, b2.basis1) compare(b1.basis2, b2.basis2) # add all params that are not parent or _basis1,_basis2 - d1 = filter_attributes(b1, exclude_keys=["_basis1", "_basis2", "_parent"]) - d2 = filter_attributes(b2, exclude_keys=["_basis1", "_basis2", "_parent"]) + d1 = filter_attributes( + b1, exclude_keys=["_basis1", "_basis2", "_parent"] + ) + d2 = filter_attributes( + b2, exclude_keys=["_basis1", "_basis2", "_parent"] + ) assert d1 == d2 else: decay_rates_b1 = b1.__dict__.get("_decay_rates", -1) @@ -2023,7 +2050,6 @@ def compare(b1, b2): add2 = add.__sklearn_clone__() compare(add, add2) - @pytest.mark.parametrize("n_basis_a", [5, 6]) @pytest.mark.parametrize("n_basis_b", [5, 6]) @pytest.mark.parametrize("sample_size", [10, 1000]) From fc23a3599540874bf5cb8de54cac620c090521c1 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 12 Dec 2024 22:15:00 -0500 Subject: [PATCH 21/37] merged stuff --- src/nemos/basis/_basis.py | 1 - tests/test_transformer_basis.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 72aeeb22..4d8fd934 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -272,7 +272,6 @@ def setup_basis(self, *xi: ArrayLike) -> FeatureMatrix: pass @abc.abstractmethod -<<<<<<< HEAD def _set_input_independent_states(self): """ Compute all the basis states that do not depend on the input. 
diff --git a/tests/test_transformer_basis.py b/tests/test_transformer_basis.py index c301adc5..d31d6eb7 100644 --- a/tests/test_transformer_basis.py +++ b/tests/test_transformer_basis.py @@ -449,7 +449,7 @@ def test_transformerbasis_pickle( False, pytest.raises( RuntimeError, - match="Cannot initialize TransformerBasis: the provided basis has no defined input shape", + match="Cannot apply TransformerBasis: the provided basis has no defined input shape", ), ), ], From 6e1e9ef13f6d598fb13cc68949ee924b64a4c377 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 12 Dec 2024 22:37:35 -0500 Subject: [PATCH 22/37] fixed attr names --- src/nemos/basis/_basis.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index f6e7bd3e..8c892fd4 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -1252,18 +1252,18 @@ def _get_feature_slicing( _merge_slicing_dicts : Merges multiple slicing dictionaries, handling keys conflicts. 
""" # Set default values for n_inputs and start_slice if not provided - n_inputs = n_inputs or self._n_basis_input + n_inputs = n_inputs or self._n_basis_input_ start_slice = start_slice or 0 # If the instance is of AdditiveBasis type, handle slicing for the additive components split_dict, start_slice = self._basis1._get_feature_slicing( - n_inputs[: len(self._basis1._n_basis_input)], + n_inputs[: len(self._basis1._n_basis_input_)], start_slice, split_by_input=split_by_input, ) sp2, start_slice = self._basis2._get_feature_slicing( - n_inputs[len(self._basis1._n_basis_input) :], + n_inputs[len(self._basis1._n_basis_input_) :], start_slice, split_by_input=split_by_input, ) @@ -1317,12 +1317,7 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None: self._n_input_dimensionality = ( basis1._n_input_dimensionality + basis2._n_input_dimensionality ) - BasisTransformerMixin.__init__(self) - self._n_basis_input = None - self._n_output_features = None - self._label = "(" + basis1.label + " * " + basis2.label + ")" - self._basis1 = basis1 - self._basis2 = basis2 + @property def n_basis_funcs(self): From ad1dc9a462ce19cb476d14d60ae63395262128a7 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 12 Dec 2024 22:58:12 -0500 Subject: [PATCH 23/37] ignore plot utils warns --- pyproject.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 4c6134ef..d20fd307 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,6 +123,11 @@ profile = "black" # Configure pytest [tool.pytest.ini_options] testpaths = ["tests"] # Specify the directory where test files are located +filterwarnings = [ + # note the use of single quote below to denote "raw" strings in TOML + 'ignore:plotting functions contained within:UserWarning', + 'ignore:Tolerance of \d\.\d+e-\d\d reached:RuntimeWarning', +] [tool.coverage.run] omit = [ From 52271c38322ad3006a1f5f2fcee4c819e964906a Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Dec 2024 00:00:19 
-0500 Subject: [PATCH 24/37] fixed docs --- .../basis/plot_01_1D_basis_function.md | 54 ++++++------------- 1 file changed, 15 insertions(+), 39 deletions(-) diff --git a/docs/background/basis/plot_01_1D_basis_function.md b/docs/background/basis/plot_01_1D_basis_function.md index f19903f7..7affd69e 100644 --- a/docs/background/basis/plot_01_1D_basis_function.md +++ b/docs/background/basis/plot_01_1D_basis_function.md @@ -50,7 +50,7 @@ glue_two_step_convolve() ## Defining a 1D Basis Object -We'll start by defining a 1D basis function object of the type [`MSplineEval`](nemos.basis.MSplineEval). +We'll start by defining a 1D basis function object of the type [`BSplineEval`](nemos.basis.BSplineEval). The hyperparameters needed to initialize this class are: - The number of basis functions, which should be a positive integer (required). @@ -85,52 +85,26 @@ plt.plot(x, y, lw=2) plt.title("B-Spline Basis") ``` -```{code-cell} ipython3 -:tags: [hide-input] - -<<<<<<< HEAD -# save image for thumbnail -from pathlib import Path -import os - -root = os.environ.get("READTHEDOCS_OUTPUT") -if root: - path = Path(root) / "html/_static/thumbnails/background" -# if local store in ../_build/html/... -else: - path = Path("../../_build/html/_static/thumbnails/background") - -# make sure the folder exists if run from build -if root or Path("../../_build/html/_static").exists(): - path.mkdir(parents=True, exist_ok=True) - -if path.exists(): - fig.savefig(path / "plot_01_1D_basis_function.svg") - - -print(path.resolve(), path.exists()) -``` ## Computing Features - All bases in the `nemos.basis` module perform a transformation of one or more time series into a set of features. This operation is always carried out by the method [`compute_features`](nemos.basis._basis.Basis.compute_features). 
-We can be group the bases into two categories depending on the type of transformation that [`compute_features`](nemos.basis._basis.Basis.compute_features) applies: +We can group the bases into two categories depending on the type of transformation that [`compute_features`](nemos.basis._basis.Basis.compute_features) applies: -1. **Evaluation Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to evaluate the basis directly, applying a non-linear transformation to the input. Classes in this category have names ends with "Eval," such as `BSplineEval`. +1. **Evaluation Bases**: These bases use `compute_features` to evaluate the basis directly, applying a non-linear transformation to the input. Classes in this category have names ending with "Eval," such as `BSplineEval`. -2. **Convolution Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to convolve the input with a kernel of basis elements, using a `window_size` specified by the user. Classes in this category have names ending with "Conv", such as `BSplineConv`. +2. **Convolution Bases**: These bases use `compute_features` to convolve the input with a kernel of basis elements, using a `window_size` specified by the user. Classes in this category have names ending with "Conv", such as `BSplineConv`. 
Let's see how these two categories operate: ```{code-cell} ipython3 -eval_mode = nmo.basis.MSplineEval(n_basis_funcs=n_basis) -conv_mode = nmo.basis.MSplineConv(n_basis_funcs=n_basis, window_size=100) +eval_mode = nmo.basis.BSplineEval(n_basis_funcs=n_basis) +conv_mode = nmo.basis.BSplineConv(n_basis_funcs=n_basis, window_size=100) # define an input angles = np.linspace(0, np.pi*4, 201) y = np.cos(angles) -# compute features in the two modalities +# compute features eval_feature = eval_mode.compute_features(y) conv_feature = conv_mode.compute_features(y) @@ -166,7 +140,6 @@ If you want to learn more about convolutions, as well as how and when to change check out the tutorial on [1D convolutions](convolution_background). ::: - ### Multi-dimensional inputs For inputs with more than one dimension, `compute_features` assumes the first axis represents samples. This is always valid for `pynapple` time series. For arrays, you can use [`numpy.transpose`](https://numpy.org/doc/stable/reference/generated/numpy.transpose.html) to re-arrange the axis if needed. @@ -212,11 +185,13 @@ Plotting the Basis Function Elements ------------------------------------ We suggest visualizing the basis post-instantiation by evaluating each element on a set of equi-spaced sample points and then plotting the result. The method [`Basis.evaluate_on_grid`](nemos.basis._basis.Basis.evaluate_on_grid) is designed for this, as it generates and returns -the equi-spaced samples along with the evaluated basis functions. The benefits of using Basis.evaluate_on_grid become -particularly evident when working with multidimensional basis functions. You can find more details and visual -background in the -[2D basis elements plotting section](plotting-2d-additive-basis-elements). +the equi-spaced samples along with the evaluated basis functions. 
+:::{admonition} Note + +The array returned by `evaluate_on_grid(n_samples)` is the same as the kernel that is used by the Conv bases initialized with `window_size=n_samples`! + +::: ```{code-cell} ipython3 # Call evaluate on grid on 100 sample points to generate samples and evaluate the basis at those samples @@ -230,12 +205,13 @@ plt.plot(equispaced_samples, eval_basis) plt.show() ``` +The benefits of using `evaluate_on_grid` become particularly evident when working with multidimensional basis functions. You can find more details in the [2D basis elements plotting section](plotting-2d-additive-basis-elements). ## Setting the basis support (Eval only) Sometimes, it is useful to restrict the basis to a fixed range. This can help manage outliers or ensure that your basis covers the same range across multiple experimental sessions. You can specify a range for the support of your basis by setting the `bounds` -parameter at initialization of "Eval" type basis (it doesn't make sense for convolutions). +parameter at initialization of Eval bases. Evaluating the basis at any sample outside the bounds will result in a NaN. From 9d6fbae248ab921bfe30e71e10be1dbadb1d9d90 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Sat, 14 Dec 2024 11:05:54 -0500 Subject: [PATCH 25/37] fixed rendering merge completed --- src/nemos/basis/_decaying_exponential.py | 3 +++ src/nemos/basis/_raised_cosine_basis.py | 5 +++++ src/nemos/basis/_spline_basis.py | 16 ++++++++++++++++ 3 files changed, 24 insertions(+) diff --git a/src/nemos/basis/_decaying_exponential.py b/src/nemos/basis/_decaying_exponential.py index a1fd4a24..5f80df58 100644 --- a/src/nemos/basis/_decaying_exponential.py +++ b/src/nemos/basis/_decaying_exponential.py @@ -22,6 +22,8 @@ class OrthExponentialBasis(Basis, AtomicBasisMixin, abc.ABC): Parameters ---------- + n_basis_funcs + Number of basis functions. decay_rates : Decay rates of the exponentials, shape ``(n_basis_funcs,)``. 
mode : @@ -34,6 +36,7 @@ class OrthExponentialBasis(Basis, AtomicBasisMixin, abc.ABC): def __init__( self, + n_basis_funcs: int, decay_rates: NDArray[np.floating], mode="eval", label: Optional[str] = "OrthExponentialBasis", diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index c964f1dc..dbf039eb 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -22,6 +22,8 @@ class RaisedCosineBasisLinear(Basis, AtomicBasisMixin, abc.ABC): Parameters ---------- + n_basis_funcs : + The number of basis functions. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. @@ -41,6 +43,7 @@ class RaisedCosineBasisLinear(Basis, AtomicBasisMixin, abc.ABC): def __init__( self, + n_basis_funcs: int, mode="eval", width: float = 2.0, label: Optional[str] = "RaisedCosineBasisLinear", @@ -232,6 +235,7 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear, abc.ABC): def __init__( self, + n_basis_funcs: int, mode="eval", width: float = 2.0, time_scaling: float = None, @@ -239,6 +243,7 @@ def __init__( label: Optional[str] = "RaisedCosineBasisLog", ) -> None: super().__init__( + n_basis_funcs, mode=mode, width=width, label=label, diff --git a/src/nemos/basis/_spline_basis.py b/src/nemos/basis/_spline_basis.py index 78cc34a6..c8f42d90 100644 --- a/src/nemos/basis/_spline_basis.py +++ b/src/nemos/basis/_spline_basis.py @@ -22,6 +22,8 @@ class SplineBasis(Basis, AtomicBasisMixin, abc.ABC): Parameters ---------- + n_basis_funcs : + Number of basis functions. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. 
@@ -39,6 +41,7 @@ class SplineBasis(Basis, AtomicBasisMixin, abc.ABC): def __init__( self, + n_basis_funcs: int, order: int = 2, label: Optional[str] = None, mode: Literal["conv", "eval"] = "eval", @@ -156,6 +159,9 @@ class MSplineBasis(SplineBasis, abc.ABC): Parameters ---------- + n_basis_funcs : + The number of basis functions to generate. More basis functions allow for + more flexible data modeling but can lead to overfitting. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. @@ -193,11 +199,13 @@ class MSplineBasis(SplineBasis, abc.ABC): def __init__( self, + n_basis_funcs: int, mode: Literal["eval", "conv"] = "eval", order: int = 2, label: Optional[str] = "MSplineEval", ) -> None: super().__init__( + n_basis_funcs, mode=mode, order=order, label=label, @@ -294,6 +302,8 @@ class BSplineBasis(SplineBasis, abc.ABC): Parameters ---------- + n_basis_funcs : + Number of basis functions. mode : The mode of operation. ``'eval'`` for evaluation at sample points, 'conv' for convolutional operation. @@ -319,11 +329,13 @@ class BSplineBasis(SplineBasis, abc.ABC): def __init__( self, + n_basis_funcs: int, mode="eval", order: int = 4, label: Optional[str] = "BSplineBasis", ): super().__init__( + n_basis_funcs, mode=mode, order=order, label=label, @@ -408,6 +420,8 @@ class CyclicBSplineBasis(SplineBasis, abc.ABC): Parameters ---------- + n_basis_funcs : + Number of basis functions. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. 
@@ -429,11 +443,13 @@ class CyclicBSplineBasis(SplineBasis, abc.ABC): def __init__( self, + n_basis_funcs: int, mode="eval", order: int = 4, label: Optional[str] = "CyclicBSplineBasis", ): super().__init__( + n_basis_funcs, mode=mode, order=order, label=label, From 95d0cd43505c0e53fa7d1643e78730b0cceaef10 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Sat, 14 Dec 2024 14:46:01 -0500 Subject: [PATCH 26/37] fixed doctests --- src/nemos/basis/_transformer_basis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index db2d5676..850aed64 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -339,10 +339,10 @@ def __setattr__(self, name: str, value) -> None: >>> trans_bas.n_basis_funcs = 20 >>> # not allowed >>> try: - ... trans_bas.random_attribute_name = "some value" + ... trans_bas.rand_atrr = "some value" ... except ValueError as e: ... print(repr(e)) - ValueError('Only setting _basis or existing attributes of _basis is allowed.') + ValueError('Only setting _basis or existing attributes of _basis is allowed. 
Attempt to set `rand_atrr`.') """ # allow self._basis = basis and other attrs of self to be retrievable if name in ["_basis", "basis", "_wrapped_methods"]: From 4234c517c46d7a9a305bd63c8856672290a31e3b Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Sun, 15 Dec 2024 12:16:14 -0500 Subject: [PATCH 27/37] add a generator --- src/nemos/basis/_transformer_basis.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 850aed64..5a76b23f 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -2,7 +2,7 @@ import copy from functools import wraps -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, Generator import numpy as np @@ -102,7 +102,7 @@ def basis(self): def basis(self, basis): self._basis = basis - def _unpack_inputs(self, X: FeatureMatrix) -> List: + def _unpack_inputs(self, X: FeatureMatrix) -> Generator: """Unpack inputs. Unpack horizontally stacked inputs using slicing. 
This works gracefully with ``pynapple``, @@ -120,13 +120,13 @@ def _unpack_inputs(self, X: FeatureMatrix) -> List: """ n_samples = X.shape[0] - out = [ + out = ( np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape_)) for i, (bas, n_input) in enumerate( zip(self._list_components(), self._n_basis_input_) ) for cc in [sum(self._n_basis_input_[:i])] - ] + ) return out def fit(self, X: FeatureMatrix, y=None): From 32a6153431dad0f7fe7e6202d791f75fd10fb383 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Sun, 15 Dec 2024 12:49:46 -0500 Subject: [PATCH 28/37] change docstrings --- src/nemos/basis/_transformer_basis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index dfbf982b..fd5d3584 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -116,7 +116,7 @@ def _unpack_inputs(self, X: FeatureMatrix) -> Generator: Returns ------- : - A list of each individual input. + A generator looping on each individual input. 
""" n_samples = X.shape[0] From 99d9c4c6167b037969c9a7313305618aca183db1 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 16 Dec 2024 09:09:35 -0500 Subject: [PATCH 29/37] typo test name --- tests/test_transformer_basis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_transformer_basis.py b/tests/test_transformer_basis.py index d31d6eb7..1f9ae8dc 100644 --- a/tests/test_transformer_basis.py +++ b/tests/test_transformer_basis.py @@ -220,7 +220,7 @@ def test_transformerbasis_setattr_basis_attribute( "basis_cls", list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), ) -def test_transformerbasis_copy_basis_on_contsruct( +def test_transformerbasis_copy_basis_on_construct( basis_cls, basis_class_specific_params ): # modifying the transformerbasis's attributes shouldn't From 33fa2748395796f882649f2f10b51742cb9227d7 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 16 Dec 2024 13:52:12 -0500 Subject: [PATCH 30/37] added transformer tests --- src/nemos/basis/_transformer_basis.py | 5 +- tests/test_transformer_basis.py | 125 ++++++++++++++++++++++++++ 2 files changed, 128 insertions(+), 2 deletions(-) diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index fd5d3584..741b93b7 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -211,6 +211,7 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: >>> # Transform basis >>> feature_transformed = transformer.transform(X) """ + self._check_input(X, y) self._check_initialized(self._basis) # transpose does not work with pynapple # can't use func(*X.T) to unwrap @@ -413,7 +414,7 @@ def get_params(self, deep: bool = True) -> dict: def __dir__(self) -> list[str]: """Extend the list of properties of methods with the ones from the underlying Basis.""" - return list(super().__dir__()) + list(self._basis.__dir__()) + return list(set(list(super().__dir__()) + list(self._basis.__dir__()))) 
def __add__(self, other: TransformerBasis) -> TransformerBasis: """ @@ -508,5 +509,5 @@ def _check_input(self, X: FeatureMatrix, y=None): if y is not None and y.shape[0] != X.shape[0]: raise ValueError( "X and y must have the same number of samples. " - f"X has {X.shpae[0]} samples, while y has {y.shape[0]} samples." + f"X has {X.shape[0]} samples, while y has {y.shape[0]} samples." ) diff --git a/tests/test_transformer_basis.py b/tests/test_transformer_basis.py index 1f9ae8dc..a99d785a 100644 --- a/tests/test_transformer_basis.py +++ b/tests/test_transformer_basis.py @@ -3,12 +3,15 @@ import numpy as np import pytest + from conftest import CombinedBasis, list_all_basis_classes from sklearn.base import clone as sk_clone from sklearn.pipeline import Pipeline import nemos as nmo from nemos import basis +from nemos._inspect_utils import list_abstract_methods, get_subclass_methods +from nemos.basis import AdditiveBasis, MultiplicativeBasis @pytest.mark.parametrize( @@ -752,3 +755,125 @@ def test_transformer_in_pipeline(basis_cls, inp, basis_class_specific_params): pipe.fit(x, y) model.fit(X, y) np.testing.assert_allclose(pipe["glm"].coef_, model.coef_) + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_initialization(basis_cls, basis_class_specific_params): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + transformer = bas.to_transformer() + with pytest.raises(RuntimeError, match="Cannot apply TransformerBasis"): + transformer.fit(np.ones((100, ))) + + with pytest.raises(RuntimeError, match="Cannot apply TransformerBasis"): + transformer.transform(np.ones((100,))) + + with pytest.raises(RuntimeError, match="Cannot apply TransformerBasis"): + transformer.fit_transform(np.ones((100,))) + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_basis_setter(basis_cls, basis_class_specific_params): + bas = CombinedBasis().instantiate_basis( + 5, 
basis_cls, basis_class_specific_params, window_size=10 + ) + + bas2 = CombinedBasis().instantiate_basis(7, basis_cls, basis_class_specific_params, window_size=10) + transformer = bas.to_transformer() + transformer.basis = bas2 + assert transformer.basis.n_basis_funcs == bas2.n_basis_funcs + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_getstate(basis_cls, basis_class_specific_params): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + transformer = bas.to_transformer() + state = transformer.__getstate__() + assert {"_basis": transformer.basis} == state + + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_eetstate(basis_cls, basis_class_specific_params): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + bas2 = CombinedBasis().instantiate_basis(7, basis_cls, basis_class_specific_params, window_size=10) + transformer = bas.to_transformer() + state = {"_basis": bas2} + transformer.__setstate__(state) + assert transformer.basis == bas2 + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +def test_eetstate(basis_cls, basis_class_specific_params): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + transformer = bas.to_transformer() + lst = transformer.__dir__() + dict_abst_method = list_abstract_methods(nmo.basis._basis.Basis) + + # check it finds all abc basis methods + for meth in dict_abst_method: + assert meth[0] in lst + + # check all reimplemented methods + dict_reimplemented_method = get_subclass_methods(basis_cls) + for meth in dict_abst_method: + assert meth[0] in lst + + # check that it is a trnasformer + for meth in ["fit", "transform", "fit_transform"]: + assert meth in lst + +@pytest.mark.parametrize( + "basis_cls", + list_all_basis_classes(), +) +@pytest.mark.parametrize( + "inp, 
expectation", + [ + (np.random.randn(10, 2), pytest.raises(ValueError, match="Input mismatch: expected \d inputs")), + (np.random.randn(10, 3, 1), pytest.raises(ValueError, match="X must be 2-dimensional")), + ({1: np.random.randn(10, 3)}, pytest.raises(ValueError, match="The input must be a 2-dimensional array")), + (np.random.randn(10, 3), does_not_raise()), + ] +) +@pytest.mark.parametrize("method", ["fit", "transform", "fit_transform"]) +def test_check_input(inp, expectation, basis_cls, basis_class_specific_params, method): + bas = CombinedBasis().instantiate_basis( + 5, basis_cls, basis_class_specific_params, window_size=10 + ) + # set kernels + bas._set_input_independent_states() + # set input shape + transformer = bas.to_transformer().set_input_shape(*([3] * bas._n_input_dimensionality)) + if isinstance(bas, (AdditiveBasis, MultiplicativeBasis)): + if hasattr(inp, "ndim"): + ndim = inp.ndim + inp = np.concatenate([inp.reshape(inp.shape[0], -1)] * bas._n_input_dimensionality, axis=1) + if ndim == 3: + inp = inp[..., np.newaxis] + + meth = getattr(transformer, method) + + with expectation: + meth(inp) + with pytest.raises(ValueError, match="X and y must have the same"): + meth(inp, np.ones(11)) From ad88294ee2d77fd7491f84a9e91ebedd3b32b400 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 16 Dec 2024 14:24:52 -0500 Subject: [PATCH 31/37] improved coverage transformer basis --- src/nemos/basis/_transformer_basis.py | 2 +- tests/test_transformer_basis.py | 58 ++++++++++++++++++--------- 2 files changed, 40 insertions(+), 20 deletions(-) diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 741b93b7..418c2580 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -211,8 +211,8 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: >>> # Transform basis >>> feature_transformed = transformer.transform(X) """ - self._check_input(X, y) 
self._check_initialized(self._basis) + self._check_input(X, y) # transpose does not work with pynapple # can't use func(*X.T) to unwrap return self._basis._compute_features(*self._unpack_inputs(X)) diff --git a/tests/test_transformer_basis.py b/tests/test_transformer_basis.py index a99d785a..79fa1848 100644 --- a/tests/test_transformer_basis.py +++ b/tests/test_transformer_basis.py @@ -3,15 +3,14 @@ import numpy as np import pytest - from conftest import CombinedBasis, list_all_basis_classes from sklearn.base import clone as sk_clone from sklearn.pipeline import Pipeline import nemos as nmo from nemos import basis -from nemos._inspect_utils import list_abstract_methods, get_subclass_methods -from nemos.basis import AdditiveBasis, MultiplicativeBasis +from nemos._inspect_utils import get_subclass_methods, list_abstract_methods +from nemos.basis import AdditiveBasis, MSplineConv, MultiplicativeBasis @pytest.mark.parametrize( @@ -699,12 +698,13 @@ def test_transformer_fit_transform_input_struct( @pytest.mark.parametrize( "inp", [ - np.random.randn( + 0.1 + * np.random.randn( 100, ), - np.random.randn(100, 1), - np.random.randn(100, 2), - np.random.randn(100, 1, 2), + 0.1 * np.random.randn(100, 1), + 0.1 * np.random.randn(100, 2), + 0.1 * np.random.randn(100, 1, 2), ], ) def test_transformer_in_pipeline(basis_cls, inp, basis_class_specific_params): @@ -756,6 +756,7 @@ def test_transformer_in_pipeline(basis_cls, inp, basis_class_specific_params): model.fit(X, y) np.testing.assert_allclose(pipe["glm"].coef_, model.coef_) + @pytest.mark.parametrize( "basis_cls", list_all_basis_classes(), @@ -766,7 +767,7 @@ def test_initialization(basis_cls, basis_class_specific_params): ) transformer = bas.to_transformer() with pytest.raises(RuntimeError, match="Cannot apply TransformerBasis"): - transformer.fit(np.ones((100, ))) + transformer.fit(np.ones((100,))) with pytest.raises(RuntimeError, match="Cannot apply TransformerBasis"): transformer.transform(np.ones((100,))) @@ -784,7 
+785,9 @@ def test_basis_setter(basis_cls, basis_class_specific_params): 5, basis_cls, basis_class_specific_params, window_size=10 ) - bas2 = CombinedBasis().instantiate_basis(7, basis_cls, basis_class_specific_params, window_size=10) + bas2 = CombinedBasis().instantiate_basis( + 7, basis_cls, basis_class_specific_params, window_size=10 + ) transformer = bas.to_transformer() transformer.basis = bas2 assert transformer.basis.n_basis_funcs == bas2.n_basis_funcs @@ -811,17 +814,20 @@ def test_eetstate(basis_cls, basis_class_specific_params): bas = CombinedBasis().instantiate_basis( 5, basis_cls, basis_class_specific_params, window_size=10 ) - bas2 = CombinedBasis().instantiate_basis(7, basis_cls, basis_class_specific_params, window_size=10) + bas2 = CombinedBasis().instantiate_basis( + 7, basis_cls, basis_class_specific_params, window_size=10 + ) transformer = bas.to_transformer() state = {"_basis": bas2} transformer.__setstate__(state) - assert transformer.basis == bas2 + assert transformer.basis == bas2 + @pytest.mark.parametrize( "basis_cls", list_all_basis_classes(), ) -def test_eetstate(basis_cls, basis_class_specific_params): +def test_getstate(basis_cls, basis_class_specific_params): bas = CombinedBasis().instantiate_basis( 5, basis_cls, basis_class_specific_params, window_size=10 ) @@ -835,13 +841,14 @@ def test_eetstate(basis_cls, basis_class_specific_params): # check all reimplemented methods dict_reimplemented_method = get_subclass_methods(basis_cls) - for meth in dict_abst_method: + for meth in dict_reimplemented_method: assert meth[0] in lst # check that it is a trnasformer for meth in ["fit", "transform", "fit_transform"]: assert meth in lst + @pytest.mark.parametrize( "basis_cls", list_all_basis_classes(), @@ -849,11 +856,20 @@ def test_eetstate(basis_cls, basis_class_specific_params): @pytest.mark.parametrize( "inp, expectation", [ - (np.random.randn(10, 2), pytest.raises(ValueError, match="Input mismatch: expected \d inputs")), - (np.random.randn(10, 
3, 1), pytest.raises(ValueError, match="X must be 2-dimensional")), - ({1: np.random.randn(10, 3)}, pytest.raises(ValueError, match="The input must be a 2-dimensional array")), + ( + np.random.randn(10, 2), + pytest.raises(ValueError, match="Input mismatch: expected \d inputs"), + ), + ( + np.random.randn(10, 3, 1), + pytest.raises(ValueError, match="X must be 2-dimensional"), + ), + ( + {1: np.random.randn(10, 3)}, + pytest.raises(ValueError, match="The input must be a 2-dimensional array"), + ), (np.random.randn(10, 3), does_not_raise()), - ] + ], ) @pytest.mark.parametrize("method", ["fit", "transform", "fit_transform"]) def test_check_input(inp, expectation, basis_cls, basis_class_specific_params, method): @@ -863,11 +879,15 @@ def test_check_input(inp, expectation, basis_cls, basis_class_specific_params, m # set kernels bas._set_input_independent_states() # set input shape - transformer = bas.to_transformer().set_input_shape(*([3] * bas._n_input_dimensionality)) + transformer = bas.to_transformer().set_input_shape( + *([3] * bas._n_input_dimensionality) + ) if isinstance(bas, (AdditiveBasis, MultiplicativeBasis)): if hasattr(inp, "ndim"): ndim = inp.ndim - inp = np.concatenate([inp.reshape(inp.shape[0], -1)] * bas._n_input_dimensionality, axis=1) + inp = np.concatenate( + [inp.reshape(inp.shape[0], -1)] * bas._n_input_dimensionality, axis=1 + ) if ndim == 3: inp = inp[..., np.newaxis] From 48889075118f20370bdb347909e82ea030951eeb Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Tue, 17 Dec 2024 16:17:57 -0500 Subject: [PATCH 32/37] typos --- src/nemos/basis/_transformer_basis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 418c2580..10a3cfcc 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -285,8 +285,8 @@ def __getattr__(self, name: str): """ Enable easy access to attributes of the underlying Basis 
object. - This method chaces all chainable methods (methods returning self) in a dicitonary. - These mehtods are created the first time they are accessed by decorating the `self._basis.name` + This method caches all chainable methods (methods returning self) in a dictionary. + These methods are created the first time they are accessed by decorating the `self._basis.name` and cached for future use. Examples From f8708d5e6fb5fcf5a359a651e4fa180ef1deb048 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 18 Dec 2024 16:23:46 -0500 Subject: [PATCH 33/37] roll-back exception handling in tutorial --- docs/how_to_guide/plot_05_transformer_basis.md | 8 ++++++-- docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/how_to_guide/plot_05_transformer_basis.md b/docs/how_to_guide/plot_05_transformer_basis.md index aed91e54..5cdd10ee 100644 --- a/docs/how_to_guide/plot_05_transformer_basis.md +++ b/docs/how_to_guide/plot_05_transformer_basis.md @@ -102,7 +102,6 @@ As with any `sckit-learn` transformer, the `TransformerBasis` implements `fit`, At this point we have an object equipped with the correct methods, so now, all we have to do is concatenate the inputs into a unique array and call `fit_transform`, right? ```{code-cell} ipython3 -:tags: [raises-exception] # reinstantiate the basis transformer for illustration porpuses composite_basis = counts_basis + speed_basis @@ -110,7 +109,12 @@ trans_bas = (composite_basis).to_transformer() # concatenate the inputs inp = np.concatenate([counts, speed[:, np.newaxis]], axis=1) print(inp.shape) -trans_bas.fit_transform(inp) + +try: + trans_bas.fit_transform(inp) +except RuntimeError as e: + print(repr(e)) + ``` ...Unfortunately, not yet. The problem is that the basis doesn't know which columns of `inp` should be processed by `count_basis` and which by `speed_basis`. 
diff --git a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md index 04d1092b..4073b928 100644 --- a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md +++ b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md @@ -169,7 +169,7 @@ trans_bas.set_input_shape(1) ``` :::{admonition} Learn More about `TransformerBasis` -:note: +:class: note To learn more about `sklearn` transformers and `TransforerBasis`, check out [this note](tansformer-vs-nemos-basis). ::: From 6f5ddef2c608798bb9b8fd9584058a682f35e259 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 18 Dec 2024 17:05:59 -0500 Subject: [PATCH 34/37] fixed tests --- tests/test_transformer_basis.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/test_transformer_basis.py b/tests/test_transformer_basis.py index 79fa1848..dcdcd4bd 100644 --- a/tests/test_transformer_basis.py +++ b/tests/test_transformer_basis.py @@ -102,8 +102,8 @@ def test_basis_to_transformer_makes_a_copy(basis_cls, basis_class_specific_param # changing an attribute in bas should not change trans_bas if basis_cls in [basis.AdditiveBasis, basis.MultiplicativeBasis]: - bas_a._basis1.n_basis_funcs = 10 - assert trans_bas_a._basis._basis1.n_basis_funcs == 5 + bas_a.basis1.n_basis_funcs = 10 + assert trans_bas_a._basis.basis1.n_basis_funcs == 5 # changing an attribute in the transformer basis should not change the original bas_b = CombinedBasis().instantiate_basis( @@ -111,8 +111,8 @@ def test_basis_to_transformer_makes_a_copy(basis_cls, basis_class_specific_param ) bas_b.set_input_shape(*([1] * bas_b._n_input_dimensionality)) trans_bas_b = bas_b.to_transformer() - trans_bas_b._basis._basis1.n_basis_funcs = 100 - assert bas_b._basis1.n_basis_funcs == 5 + trans_bas_b._basis.basis1.n_basis_funcs = 100 + assert bas_b.basis1.n_basis_funcs == 5 else: bas_a.n_basis_funcs = 10 assert trans_bas_a.n_basis_funcs == 5 @@ -144,7 +144,7 @@ def 
test_transformerbasis_getattr( ) if basis_cls in [basis.AdditiveBasis, basis.MultiplicativeBasis]: for bas in [ - getattr(trans_basis._basis, attr) for attr in ("_basis1", "_basis2") + getattr(trans_basis._basis, attr) for attr in ("basis1", "basis2") ]: assert bas.n_basis_funcs == n_basis_funcs else: @@ -292,8 +292,8 @@ def test_transformerbasis_addition(basis_cls, basis_class_specific_params): == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality ) if basis_cls not in [basis.AdditiveBasis, basis.MultiplicativeBasis]: - assert trans_bas_sum._basis1.n_basis_funcs == n_basis_funcs_a - assert trans_bas_sum._basis2.n_basis_funcs == n_basis_funcs_b + assert trans_bas_sum.basis1.n_basis_funcs == n_basis_funcs_a + assert trans_bas_sum.basis2.n_basis_funcs == n_basis_funcs_b @pytest.mark.parametrize( @@ -327,8 +327,8 @@ def test_transformerbasis_multiplication(basis_cls, basis_class_specific_params) == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality ) if basis_cls not in [basis.AdditiveBasis, basis.MultiplicativeBasis]: - assert trans_bas_prod._basis1.n_basis_funcs == n_basis_funcs_a - assert trans_bas_prod._basis2.n_basis_funcs == n_basis_funcs_b + assert trans_bas_prod.basis1.n_basis_funcs == n_basis_funcs_a + assert trans_bas_prod.basis2.n_basis_funcs == n_basis_funcs_b @pytest.mark.parametrize( @@ -436,7 +436,7 @@ def test_transformerbasis_pickle( assert isinstance(trans_bas2, basis.TransformerBasis) if basis_cls in [basis.AdditiveBasis, basis.MultiplicativeBasis]: for bas in [ - getattr(trans_bas2._basis, attr) for attr in ("_basis1", "_basis2") + getattr(trans_bas2._basis, attr) for attr in ("basis1", "basis2") ]: assert bas.n_basis_funcs == n_basis_funcs else: @@ -739,7 +739,7 @@ def test_transformer_in_pipeline(basis_cls, inp, basis_class_specific_params): # set basis & refit if isinstance(bas, (basis.AdditiveBasis, basis.MultiplicativeBasis)): - pipe.set_params(bas__basis2__n_basis_funcs=4) + 
pipe.set_params(bas_basis2__n_basis_funcs=4) assert ( bas.basis2.n_basis_funcs == 5 ) # make sure that the change did not affect bas From 2ddbabed5e7ca1879270cbb5a3cd7776f9d8b8ae Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 18 Dec 2024 17:43:18 -0500 Subject: [PATCH 35/37] fixed tests --- tests/test_transformer_basis.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_transformer_basis.py b/tests/test_transformer_basis.py index dcdcd4bd..6ea3fbb1 100644 --- a/tests/test_transformer_basis.py +++ b/tests/test_transformer_basis.py @@ -714,7 +714,8 @@ def test_transformer_in_pipeline(basis_cls, inp, basis_class_specific_params): transformer = bas.set_input_shape( *([inp] * bas._n_input_dimensionality) ).to_transformer() - + if isinstance(bas, AdditiveBasis): + xxx=1 # fit outside pipeline X = bas.compute_features(*([inp] * bas._n_input_dimensionality)) log_mu = X.dot(0.005 * np.ones(X.shape[1])) @@ -739,7 +740,7 @@ def test_transformer_in_pipeline(basis_cls, inp, basis_class_specific_params): # set basis & refit if isinstance(bas, (basis.AdditiveBasis, basis.MultiplicativeBasis)): - pipe.set_params(bas_basis2__n_basis_funcs=4) + pipe.set_params(bas__basis2__n_basis_funcs=4) assert ( bas.basis2.n_basis_funcs == 5 ) # make sure that the change did not affect bas From 46327c34e8aeca03bf7a22bf828ce7ba0226bb00 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 18 Dec 2024 17:48:03 -0500 Subject: [PATCH 36/37] removed debug --- tests/test_transformer_basis.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_transformer_basis.py b/tests/test_transformer_basis.py index 6ea3fbb1..6dd7b247 100644 --- a/tests/test_transformer_basis.py +++ b/tests/test_transformer_basis.py @@ -714,8 +714,6 @@ def test_transformer_in_pipeline(basis_cls, inp, basis_class_specific_params): transformer = bas.set_input_shape( *([inp] * bas._n_input_dimensionality) ).to_transformer() - if isinstance(bas, AdditiveBasis): - xxx=1 # fit 
outside pipeline X = bas.compute_features(*([inp] * bas._n_input_dimensionality)) log_mu = X.dot(0.005 * np.ones(X.shape[1])) From 259f9e78bdf72b6edf7af10b9b86dcde53de85b1 Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Thu, 19 Dec 2024 16:18:25 -0500 Subject: [PATCH 37/37] Update src/nemos/basis/_transformer_basis.py Co-authored-by: William F. Broderick --- src/nemos/basis/_transformer_basis.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 10a3cfcc..720adf08 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -153,8 +153,10 @@ def fit(self, X: FeatureMatrix, y=None): Raises ------ + RuntimeError + If ``self.n_basis_input`` is None. Call ``self.set_input_shape`` before calling ``fit`` to avoid this. ValueError: - If the number of columns in X do not match the number of inputs that the basis expects. + If the number of columns in X does not match ``self.n_basis_input_``. Examples --------