diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index e2f4a762..06af0f9f 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -133,10 +133,9 @@ class Basis(Base, abc.ABC, BasisTransformerMixin): def __init__( self, - mode: Literal["eval", "conv"] = "eval", + mode: Literal["eval", "conv", "composite"] = "eval", label: Optional[str] = None, ) -> None: - self._n_basis_funcs = getattr(self, "_n_basis_funcs", None) self._n_input_dimensionality = getattr(self, "_n_input_dimensionality", 0) self._mode = mode @@ -147,8 +146,8 @@ def __init__( self._label = str(label) # specified only after inputs/input shapes are provided - self._n_basis_input_ = getattr(self, "_n_basis_input_", None) - self._input_shape_ = getattr(self, "_input_shape_", None) + self._n_basis_input_ = None + self._input_shape_ = None # initialize parent to None. This should not end in "_" because it is # a permanent property of a basis, defined at composite basis init @@ -743,7 +742,7 @@ class AdditiveBasis(CompositeBasisMixin, Basis): def __init__(self, basis1: Basis, basis2: Basis) -> None: CompositeBasisMixin.__init__(self, basis1, basis2) - Basis.__init__(self, mode="eval") + Basis.__init__(self, mode="composite") self._label = "(" + basis1.label + " + " + basis2.label + ")" self._n_input_dimensionality = ( @@ -1154,7 +1153,7 @@ class MultiplicativeBasis(CompositeBasisMixin, Basis): def __init__(self, basis1: Basis, basis2: Basis) -> None: CompositeBasisMixin.__init__(self, basis1, basis2) - Basis.__init__(self, mode="eval") + Basis.__init__(self, mode="composite") self._label = "(" + basis1.label + " * " + basis2.label + ")" self._n_input_dimensionality = ( basis1._n_input_dimensionality + basis2._n_input_dimensionality diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index b708ec3a..3b40b4b6 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -5,7 +5,6 @@ import abc import copy import inspect -import warnings from functools import wraps from itertools import chain from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union @@ -21,7 +20,9 @@ from ._basis import Basis -def set_input_shape_state(method): +def set_input_shape_state( + method, states: Tuple[str] = ("_n_basis_input_", "_input_shape_") +): """ Decorator to preserve input shape-related attributes during method execution. @@ -36,6 +37,7 @@ def set_input_shape_state(method): method : The method to be wrapped. This method is expected to return an object (`klass`) that requires the `_n_basis_input_` and `_input_shape_` attributes. + attr_list Returns ------- @@ -60,7 +62,7 @@ def set_input_shape_state(method): @wraps(method) def wrapper(self, *args, **kwargs): klass: Basis = method(self, *args, **kwargs) - for attr_name in ["_n_basis_input_", "_input_shape_"]: + for attr_name in states: setattr(klass, attr_name, getattr(self, attr_name)) return klass @@ -84,9 +86,6 @@ def __sklearn_clone__(self) -> Basis: cross-validation unusable. """ klass = self.__class__(**self.get_params()) - - for attr_name in ["_n_basis_input_", "_input_shape_"]: - setattr(klass, attr_name, getattr(self, attr_name)) return klass def _iterate_over_components(self) -> Generator: @@ -519,12 +518,6 @@ def __init__(self, basis1: Basis, basis2: Basis): if all(set_bases): # pass down the input shapes self.set_input_shape(*shapes) - elif any(set_bases): - warnings.warn( - "Only some of the basis where initialized with `set_input_shape`, " - "please initialize the composite basis before computing features.", - category=UserWarning, - ) @property @abc.abstractmethod diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 2ea0bd86..8601a101 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -1830,19 +1830,3 @@ def _check_window_size(self, window_size: int): f"of basis functions. window_size is {window_size}, n_basis_funcs while" f"is {self.n_basis_funcs}." ) - - def set_kernel(self): - try: - super().set_kernel() - except ValueError as e: - if "OrthExponentialBasis requires at least as many" in str(e): - raise ValueError( - "Cannot set the kernels for OrthExponentialBasis when `window_size` is smaller " - "than `n_basis_funcs.\n" - "Please, increase the window size or reduce the number of basis functions. " - f"Current `window_size` is {self.window_size}, while `n_basis_funcs` is " - f"{self.n_basis_funcs}." - ) - else: - raise e - return self diff --git a/tests/test_basis.py b/tests/test_basis.py index 2cace735..e29b6cdd 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -2255,41 +2255,6 @@ def test_number_of_required_inputs_compute_features( with expectation: basis_obj.compute_features(*inputs) - @pytest.mark.parametrize("basis_a", list_all_basis_classes()) - @pytest.mark.parametrize("basis_b", list_all_basis_classes()) - @pytest.mark.parametrize("n_basis_a", [5]) - @pytest.mark.parametrize("n_basis_b", [6]) - @pytest.mark.parametrize("window_size", [10]) - def test_warn_partial_setup( - self, - n_basis_a, - n_basis_b, - basis_a, - basis_b, - window_size, - basis_class_specific_params, - ): - basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, basis_class_specific_params, window_size=window_size - ) - basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, basis_class_specific_params, window_size=window_size - ) - - basis_a_obj.set_input_shape(*([1] * basis_a_obj._n_input_dimensionality)) - with pytest.warns(UserWarning, match="Only some of the basis where"): - basis_a_obj + basis_b_obj - - # check that if both set addition is fine - basis_b_obj.set_input_shape(*([1] * basis_b_obj._n_input_dimensionality)) - basis_a_obj + basis_b_obj - - basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, basis_class_specific_params, window_size=window_size - ) - with pytest.warns(UserWarning, match="Only some of the basis where"): - basis_a_obj + basis_b_obj - @pytest.mark.parametrize("sample_size", [11, 20]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes())