diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 46d9f5a5..51538a06 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -625,8 +625,7 @@ def _get_feature_slicing( _get_default_slicing : Handles default slicing logic. _merge_slicing_dicts : Merges multiple slicing dictionaries, handling keys conflicts. """ - # Set default values for n_inputs and start_slice if not provided - n_inputs = n_inputs or self._n_basis_input_ + # Set default values for start_slice if not provided start_slice = start_slice or 0 # Handle the default case for non-additive basis types # See overwritten method for recursion logic @@ -816,28 +815,6 @@ class is accidentally removed. ) return [self] - def __sklearn_clone__(self) -> Basis: - """Clone the basis while preserving attributes related to input shapes. - - This method ensures that input shape attributes (e.g., `_n_basis_input_`, - `_input_shape_`) are preserved during cloning. Reinitializing the class - as in the regular sklearn clone would drop these attributes, rendering - cross-validation unusable. - The method also handles recursive cloning for composite basis structures. - """ - # clone recursively - if hasattr(self, "_basis1") and hasattr(self, "_basis2"): - basis1 = self._basis1.__sklearn_clone__() - basis2 = self._basis2.__sklearn_clone__() - klass = self.__class__(basis1, basis2) - - else: - klass = self.__class__(**self.get_params()) - - for attr_name in ["_n_basis_input_", "_input_shape_"]: - setattr(klass, attr_name, getattr(self, attr_name)) - return klass - class AdditiveBasis(CompositeBasisMixin, Basis): """ diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 9b208e31..e7413afd 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -6,6 +6,7 @@ import copy import inspect import warnings +from functools import wraps from typing import TYPE_CHECKING, Optional, Tuple, Union import numpy as np @@ -19,6 +20,52 @@ from ._basis import Basis +def set_input_shape_state(method): + """ + Decorator to preserve input shape-related attributes during method execution. + + This decorator ensures that the attributes `_n_basis_input_` and `_input_shape_` + are copied from the original object (`self`) to the returned object (`klass`) + after the wrapped method executes. It is intended to be used with methods that + clone or create a new instance of the class, ensuring these critical attributes + are retained for functionality such as cross-validation. + + Parameters + ---------- + method : + The method to be wrapped. This method is expected to return an object + (`klass`) that requires the `_n_basis_input_` and `_input_shape_` attributes. + + Returns + ------- + : + The wrapped method that copies `_n_basis_input_` and `_input_shape_` from + the original object (`self`) to the new object (`klass`). + + Examples + -------- + Applying the decorator to a method: + + >>> from functools import wraps + >>> @set_input_shape_state + ... def __sklearn_clone__(self): + ... klass = self.__class__(**self.get_params()) + ... return klass + + The `_n_basis_input_` and `_input_shape_` attributes of `self` will be + copied to `klass` after the method executes. + """ + + @wraps(method) + def wrapper(self, *args, **kwargs): + klass: Basis = method(self, *args, **kwargs) + for attr_name in ["_n_basis_input_", "_input_shape_"]: + setattr(klass, attr_name, getattr(self, attr_name)) + return klass + + return wrapper + + class EvalBasisMixin: """Mixin class for evaluational basis.""" @@ -118,6 +165,21 @@ def bounds(self, values: Union[None, Tuple[float, float]]): f"Invalid bound {values}. Lower bound is greater or equal than the upper bound." ) + @set_input_shape_state + def __sklearn_clone__(self) -> Basis: + """Clone the basis while preserving attributes related to input shapes. + + This method ensures that input shape attributes (e.g., `_n_basis_input_`, + `_input_shape_`) are preserved during cloning. Reinitializing the class + as in the regular sklearn clone would drop these attributes, rendering + cross-validation unusable. + """ + klass = self.__class__(**self.get_params()) + + # for attr_name in ["_n_basis_input_", "_input_shape_"]: + # setattr(klass, attr_name, getattr(self, attr_name)) + return klass + class ConvBasisMixin: """Mixin class for convolutional basis.""" @@ -309,6 +371,21 @@ def _check_has_kernel(self) -> None: "You must call `_set_kernel` before `_compute_features` for Conv basis." ) + @set_input_shape_state + def __sklearn_clone__(self) -> Basis: + """Clone the basis while preserving attributes related to input shapes. + + This method ensures that input shape attributes (e.g., `_n_basis_input_`, + `_input_shape_`) are preserved during cloning. Reinitializing the class + as in the regular sklearn clone would drop these attributes, rendering + cross-validation unusable. + """ + klass = self.__class__(**self.get_params()) + + for attr_name in ["_n_basis_input_", "_input_shape_"]: + setattr(klass, attr_name, getattr(self, attr_name)) + return klass + class BasisTransformerMixin: """Mixin class for constructing a transformer.""" @@ -462,3 +539,22 @@ def _list_components(self): A list with all 1d basis components. """ return self._basis1._list_components() + self._basis2._list_components() + + @set_input_shape_state + def __sklearn_clone__(self) -> Basis: + """Clone the basis while preserving attributes related to input shapes. + + This method ensures that input shape attributes (e.g., `_n_basis_input_`, + `_input_shape_`) are preserved during cloning. Reinitializing the class + as in the regular sklearn clone would drop these attributes, rendering + cross-validation unusable. + The method also handles recursive cloning for composite basis structures. + """ + # clone recursively + basis1 = self._basis1.__sklearn_clone__() + basis2 = self._basis2.__sklearn_clone__() + klass = self.__class__(basis1, basis2) + + for attr_name in ["_n_basis_input_", "_input_shape_"]: + setattr(klass, attr_name, getattr(self, attr_name)) + return klass