Skip to content

Commit

Permalink
fixes from PR
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 18, 2024
1 parent 0038600 commit efc138f
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 69 deletions.
11 changes: 5 additions & 6 deletions src/nemos/basis/_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,9 @@ class Basis(Base, abc.ABC, BasisTransformerMixin):

def __init__(
self,
mode: Literal["eval", "conv"] = "eval",
mode: Literal["eval", "conv", "composite"] = "eval",
label: Optional[str] = None,
) -> None:
self._n_basis_funcs = getattr(self, "_n_basis_funcs", None)
self._n_input_dimensionality = getattr(self, "_n_input_dimensionality", 0)

self._mode = mode
Expand All @@ -147,8 +146,8 @@ def __init__(
self._label = str(label)

# specified only after inputs/input shapes are provided
self._n_basis_input_ = getattr(self, "_n_basis_input_", None)
self._input_shape_ = getattr(self, "_input_shape_", None)
self._n_basis_input_ = None
self._input_shape_ = None

# initialize parent to None. This should not end in "_" because it is
# a permanent property of a basis, defined at composite basis init
Expand Down Expand Up @@ -743,7 +742,7 @@ class AdditiveBasis(CompositeBasisMixin, Basis):

def __init__(self, basis1: Basis, basis2: Basis) -> None:
CompositeBasisMixin.__init__(self, basis1, basis2)
Basis.__init__(self, mode="eval")
Basis.__init__(self, mode="composite")
self._label = "(" + basis1.label + " + " + basis2.label + ")"

self._n_input_dimensionality = (
Expand Down Expand Up @@ -1154,7 +1153,7 @@ class MultiplicativeBasis(CompositeBasisMixin, Basis):

def __init__(self, basis1: Basis, basis2: Basis) -> None:
CompositeBasisMixin.__init__(self, basis1, basis2)
Basis.__init__(self, mode="eval")
Basis.__init__(self, mode="composite")
self._label = "(" + basis1.label + " * " + basis2.label + ")"
self._n_input_dimensionality = (
basis1._n_input_dimensionality + basis2._n_input_dimensionality
Expand Down
17 changes: 5 additions & 12 deletions src/nemos/basis/_basis_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import abc
import copy
import inspect
import warnings
from functools import wraps
from itertools import chain
from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union
Expand All @@ -21,7 +20,9 @@
from ._basis import Basis


def set_input_shape_state(method):
def set_input_shape_state(
method, states: Tuple[str] = ("_n_basis_input_", "_input_shape_")
):
"""
Decorator to preserve input shape-related attributes during method execution.
Expand All @@ -36,6 +37,7 @@ def set_input_shape_state(method):
method :
The method to be wrapped. This method is expected to return an object
(`klass`) that requires the `_n_basis_input_` and `_input_shape_` attributes.
attr_list
Returns
-------
Expand All @@ -60,7 +62,7 @@ def set_input_shape_state(method):
@wraps(method)
def wrapper(self, *args, **kwargs):
klass: Basis = method(self, *args, **kwargs)
for attr_name in ["_n_basis_input_", "_input_shape_"]:
for attr_name in states:
setattr(klass, attr_name, getattr(self, attr_name))
return klass

Expand All @@ -84,9 +86,6 @@ def __sklearn_clone__(self) -> Basis:
cross-validation unusable.
"""
klass = self.__class__(**self.get_params())

for attr_name in ["_n_basis_input_", "_input_shape_"]:
setattr(klass, attr_name, getattr(self, attr_name))
return klass

def _iterate_over_components(self) -> Generator:
Expand Down Expand Up @@ -519,12 +518,6 @@ def __init__(self, basis1: Basis, basis2: Basis):
if all(set_bases):
# pass down the input shapes
self.set_input_shape(*shapes)
elif any(set_bases):
warnings.warn(
"Only some of the basis where initialized with `set_input_shape`, "
"please initialize the composite basis before computing features.",
category=UserWarning,
)

@property
@abc.abstractmethod
Expand Down
16 changes: 0 additions & 16 deletions src/nemos/basis/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1830,19 +1830,3 @@ def _check_window_size(self, window_size: int):
f"of basis functions. window_size is {window_size}, n_basis_funcs while"
f"is {self.n_basis_funcs}."
)

def set_kernel(self):
try:
super().set_kernel()
except ValueError as e:
if "OrthExponentialBasis requires at least as many" in str(e):
raise ValueError(
"Cannot set the kernels for OrthExponentialBasis when `window_size` is smaller "
"than `n_basis_funcs.\n"
"Please, increase the window size or reduce the number of basis functions. "
f"Current `window_size` is {self.window_size}, while `n_basis_funcs` is "
f"{self.n_basis_funcs}."
)
else:
raise e
return self
35 changes: 0 additions & 35 deletions tests/test_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2255,41 +2255,6 @@ def test_number_of_required_inputs_compute_features(
with expectation:
basis_obj.compute_features(*inputs)

@pytest.mark.parametrize("basis_a", list_all_basis_classes())
@pytest.mark.parametrize("basis_b", list_all_basis_classes())
@pytest.mark.parametrize("n_basis_a", [5])
@pytest.mark.parametrize("n_basis_b", [6])
@pytest.mark.parametrize("window_size", [10])
def test_warn_partial_setup(
self,
n_basis_a,
n_basis_b,
basis_a,
basis_b,
window_size,
basis_class_specific_params,
):
basis_a_obj = self.instantiate_basis(
n_basis_a, basis_a, basis_class_specific_params, window_size=window_size
)
basis_b_obj = self.instantiate_basis(
n_basis_b, basis_b, basis_class_specific_params, window_size=window_size
)

basis_a_obj.set_input_shape(*([1] * basis_a_obj._n_input_dimensionality))
with pytest.warns(UserWarning, match="Only some of the basis where"):
basis_a_obj + basis_b_obj

# check that if both set addition is fine
basis_b_obj.set_input_shape(*([1] * basis_b_obj._n_input_dimensionality))
basis_a_obj + basis_b_obj

basis_a_obj = self.instantiate_basis(
n_basis_a, basis_a, basis_class_specific_params, window_size=window_size
)
with pytest.warns(UserWarning, match="Only some of the basis where"):
basis_a_obj + basis_b_obj

@pytest.mark.parametrize("sample_size", [11, 20])
@pytest.mark.parametrize("basis_a", list_all_basis_classes())
@pytest.mark.parametrize("basis_b", list_all_basis_classes())
Expand Down

0 comments on commit efc138f

Please sign in to comment.