Skip to content

Commit

Permalink
improved modularity of sklearn clone
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 13, 2024
1 parent 47174ec commit fd95bcc
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 24 deletions.
25 changes: 1 addition & 24 deletions src/nemos/basis/_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,8 +625,7 @@ def _get_feature_slicing(
_get_default_slicing : Handles default slicing logic.
_merge_slicing_dicts : Merges multiple slicing dictionaries, handling keys conflicts.
"""
# Set default values for n_inputs and start_slice if not provided
n_inputs = n_inputs or self._n_basis_input_
# Set default values for start_slice if not provided
start_slice = start_slice or 0
# Handle the default case for non-additive basis types
# See overwritten method for recursion logic
Expand Down Expand Up @@ -816,28 +815,6 @@ class is accidentally removed.
)
return [self]

def __sklearn_clone__(self) -> Basis:
"""Clone the basis while preserving attributes related to input shapes.
This method ensures that input shape attributes (e.g., `_n_basis_input_`,
`_input_shape_`) are preserved during cloning. Reinitializing the class
as in the regular sklearn clone would drop these attributes, rendering
cross-validation unusable.
The method also handles recursive cloning for composite basis structures.
"""
# clone recursively
if hasattr(self, "_basis1") and hasattr(self, "_basis2"):
basis1 = self._basis1.__sklearn_clone__()
basis2 = self._basis2.__sklearn_clone__()
klass = self.__class__(basis1, basis2)

else:
klass = self.__class__(**self.get_params())

for attr_name in ["_n_basis_input_", "_input_shape_"]:
setattr(klass, attr_name, getattr(self, attr_name))
return klass


class AdditiveBasis(CompositeBasisMixin, Basis):
"""
Expand Down
96 changes: 96 additions & 0 deletions src/nemos/basis/_basis_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import copy
import inspect
import warnings
from functools import wraps
from typing import TYPE_CHECKING, Optional, Tuple, Union

import numpy as np
Expand All @@ -19,6 +20,52 @@
from ._basis import Basis


def set_input_shape_state(method):
"""
Decorator to preserve input shape-related attributes during method execution.
This decorator ensures that the attributes `_n_basis_input_` and `_input_shape_`
are copied from the original object (`self`) to the returned object (`klass`)
after the wrapped method executes. It is intended to be used with methods that
clone or create a new instance of the class, ensuring these critical attributes
are retained for functionality such as cross-validation.
Parameters
----------
method :
The method to be wrapped. This method is expected to return an object
(`klass`) that requires the `_n_basis_input_` and `_input_shape_` attributes.
Returns
-------
:
The wrapped method that copies `_n_basis_input_` and `_input_shape_` from
the original object (`self`) to the new object (`klass`).
Examples
--------
Applying the decorator to a method:
>>> from functools import wraps
>>> @set_input_shape_state
... def __sklearn_clone__(self):
... klass = self.__class__(**self.get_params())
... return klass
The `_n_basis_input_` and `_input_shape_` attributes of `self` will be
copied to `klass` after the method executes.
"""

@wraps(method)
def wrapper(self, *args, **kwargs):
klass: Basis = method(self, *args, **kwargs)
for attr_name in ["_n_basis_input_", "_input_shape_"]:
setattr(klass, attr_name, getattr(self, attr_name))
return klass

return wrapper


class EvalBasisMixin:
"""Mixin class for evaluational basis."""

Expand Down Expand Up @@ -118,6 +165,21 @@ def bounds(self, values: Union[None, Tuple[float, float]]):
f"Invalid bound {values}. Lower bound is greater or equal than the upper bound."
)

@set_input_shape_state
def __sklearn_clone__(self) -> Basis:
"""Clone the basis while preserving attributes related to input shapes.
This method ensures that input shape attributes (e.g., `_n_basis_input_`,
`_input_shape_`) are preserved during cloning. Reinitializing the class
as in the regular sklearn clone would drop these attributes, rendering
cross-validation unusable.
"""
klass = self.__class__(**self.get_params())

# for attr_name in ["_n_basis_input_", "_input_shape_"]:
# setattr(klass, attr_name, getattr(self, attr_name))
return klass


class ConvBasisMixin:
"""Mixin class for convolutional basis."""
Expand Down Expand Up @@ -309,6 +371,21 @@ def _check_has_kernel(self) -> None:
"You must call `_set_kernel` before `_compute_features` for Conv basis."
)

@set_input_shape_state
def __sklearn_clone__(self) -> Basis:
"""Clone the basis while preserving attributes related to input shapes.
This method ensures that input shape attributes (e.g., `_n_basis_input_`,
`_input_shape_`) are preserved during cloning. Reinitializing the class
as in the regular sklearn clone would drop these attributes, rendering
cross-validation unusable.
"""
klass = self.__class__(**self.get_params())

for attr_name in ["_n_basis_input_", "_input_shape_"]:
setattr(klass, attr_name, getattr(self, attr_name))
return klass


class BasisTransformerMixin:
"""Mixin class for constructing a transformer."""
Expand Down Expand Up @@ -462,3 +539,22 @@ def _list_components(self):
A list with all 1d basis components.
"""
return self._basis1._list_components() + self._basis2._list_components()

@set_input_shape_state
def __sklearn_clone__(self) -> Basis:
"""Clone the basis while preserving attributes related to input shapes.
This method ensures that input shape attributes (e.g., `_n_basis_input_`,
`_input_shape_`) are preserved during cloning. Reinitializing the class
as in the regular sklearn clone would drop these attributes, rendering
cross-validation unusable.
The method also handles recursive cloning for composite basis structures.
"""
# clone recursively
basis1 = self._basis1.__sklearn_clone__()
basis2 = self._basis2.__sklearn_clone__()
klass = self.__class__(basis1, basis2)

for attr_name in ["_n_basis_input_", "_input_shape_"]:
setattr(klass, attr_name, getattr(self, attr_name))
return klass

0 comments on commit fd95bcc

Please sign in to comment.