From ad45ea79fe91d1150fce231cb9457ec4922c1902 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 5 Dec 2024 15:34:09 -0500 Subject: [PATCH 01/41] added tests --- src/nemos/basis/_basis.py | 212 +++++++++++++------- src/nemos/basis/_basis_mixin.py | 39 +++- src/nemos/basis/basis.py | 335 ++++++++++++++++++++++++++++++++ tests/test_basis.py | 237 +++++++++++++++++++++- 4 files changed, 748 insertions(+), 75 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 69020893..04fa02b4 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -16,7 +16,7 @@ from ..typing import FeatureMatrix from ..utils import row_wise_kron from ..validation import check_fraction_valid_samples -from ._basis_mixin import BasisTransformerMixin +from ._basis_mixin import BasisTransformerMixin, CompositeBasisMixin def add_docstring(method_name, cls): @@ -271,7 +271,9 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: Subclasses should implement how to handle the transformation specific to their basis function types and operation modes. """ - self._set_num_output_features(*xi) + if self._n_basis_input is None: + self.set_input_shape(*xi) + self._check_input_shape_consistency(*xi) self.set_kernel() return self._compute_features(*xi) @@ -752,9 +754,12 @@ def is_leaf(val): def _check_input_shape_consistency(self, x: NDArray): """Check input consistency across calls.""" - # remove sample axis + # remove sample axis and squeeze shape = x.shape[1:] - if self._input_shape is not None and self._input_shape != shape: + + initialized = self._input_shape is not None + is_shape_match = self._input_shape == shape + if initialized and not is_shape_match: expected_shape_str = "(n_samples, " + f"{self._input_shape}"[1:] expected_shape_str = expected_shape_str.replace(",)", ")") raise ValueError( @@ -768,60 +773,53 @@ def _check_input_shape_consistency(self, x: NDArray): "different shape, please create a new basis instance." ) - def _set_num_output_features(self, *xi: NDArray) -> Basis: + def set_input_shape(self, xi: int | tuple[int,...] | NDArray): """ - Pre-compute the number of inputs and output features. + Set the expected input shape for the basis object. - This function computes the number of inputs that are provided to the basis and uses - that number, and the n_basis_funcs to calculate the number of output features that - ``self.compute_features`` will return. These quantities and the input shape (excluding the sample axis) - are stored in ``self._n_basis_input`` and ``self._n_output_features``, and ``self._input_shape`` - respectively. + This method configures the shape of the input data that the basis object expects. + ``xi`` can be specified as an integer, a tuple of integers, or derived + from an array. The method also calculates the total number of input + features and output features based on the number of basis functions. Parameters ---------- - xi: - The input arrays. - - Returns - ------- - : - The basis itself, for chaining. + xi : + The input shape specification. + - An integer: Represents the dimensionality of the input. A value of ``1`` is treated as scalar input. + - A tuple: Represents the exact input shape excluding the first axis (sample axis). + All elements must be integers. + - An array: The shape is extracted, excluding the first axis (assumed to be the sample axis). Raises ------ - ValueError: - If the number of inputs do not match ``self._n_basis_input``, if ``self._n_basis_input`` was - not None. - - Notes - ----- - Once a ``compute_features`` is called, we enforce that for all subsequent calls of the method, - the input that the basis receives preserves the shape of all axes, except for the sample axis. - This condition guarantees the consistency of the feature axis, and therefore that - ``self.split_by_feature`` behaves appropriately. + ValueError + If a tuple is provided and it contains non-integer elements. + Returns + ------- + self : + Returns the instance itself to allow method chaining. """ - # Check that the input shape matches expectation - # Note that this method is reimplemented in AdditiveBasis and MultiplicativeBasis - # so we can assume that len(xi) == 1 - xi = xi[0] - self._check_input_shape_consistency(xi) + if isinstance(xi, tuple): + if not all(isinstance(i, int) for i in xi): + raise ValueError(f"The tuple provided contains non integer values. Tuple: {xi}.") + shape = xi + elif isinstance(xi, int): + shape = () if xi == 1 else (xi,) + else: + shape = xi.shape[1:] - # remove sample axis (samples are allowed to vary) - shape = xi.shape[1:] + n_inputs = (int(np.prod(shape)),) self._input_shape = shape - # remove sample axis & get the total input number - n_inputs = (1,) if xi.ndim == 1 else (np.prod(shape),) - self._n_basis_input = n_inputs self._n_output_features = self.n_basis_funcs * self._n_basis_input[0] return self -class AdditiveBasis(Basis): +class AdditiveBasis(CompositeBasisMixin, Basis): """ Class representing the addition of two Basis objects. @@ -866,13 +864,62 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None: self._label = "(" + basis1.label + " + " + basis2.label + ")" self._basis1 = basis1 self._basis2 = basis2 + CompositeBasisMixin.__init__(self) + + def set_input_shape(self, *xi: int | tuple[int,...] | NDArray) -> Basis: + """ + Set the expected input shape for the basis object. + + This method sets the input shape for each component basis in the ``AdditiveBasis``. + One ``xi`` must be provided for each basis component, specified as an integer, + a tuple of integers, or an array. The method calculates and stores the total number of output features + based on the number of basis functions in each component and the provided input shapes. - def _set_num_output_features(self, *xi: NDArray) -> Basis: + Parameters + ---------- + *xi : + The input shape specifications. For every k, ``xi[k]`` can be: + - An integer: Represents the dimensionality of the input. A value of ``1`` is treated as scalar input. + - A tuple: Represents the exact input shape excluding the first axis (sample axis). + All elements must be integers. + - An array: The shape is extracted, excluding the first axis (assumed to be the sample axis). + + Raises + ------ + ValueError + If a tuple is provided and it contains non-integer elements. + + Returns + ------- + self : + Returns the instance itself to allow method chaining. + + Examples + -------- + >>> # Generate sample data + >>> import numpy as np + >>> import nemos as nmo + + >>> # define an additive basis + >>> basis_1 = nmo.basis.BSplineEval(5) + >>> basis_2 = nmo.basis.RaisedCosineLinearEval(6) + >>> basis_3 = nmo.basis.RaisedCosineLinearEval(7) + >>> additive_basis = basis_1 + basis_2 + basis_3 + + Specify the input shape using all 3 allowed ways: integer, tuple, array + >>> _ = additive_basis.set_input_shape(1, (2, 3), np.ones((10, 4, 5))) + + Expected output features are: + (5 bases * 1 input) + (6 bases * 6 inputs) + (7 bases * 20 inputs) = 181 + >>> additive_basis.n_output_features + 181 + + """ self._n_basis_input = ( - *self._basis1._set_num_output_features( + *self._basis1.set_input_shape( *xi[: self._basis1._n_input_dimensionality] )._n_basis_input, - *self._basis2._set_num_output_features( + *self._basis2.set_input_shape( *xi[self._basis1._n_input_dimensionality :] )._n_basis_input, ) @@ -881,9 +928,6 @@ def _set_num_output_features(self, *xi: NDArray) -> Basis: ) return self - def _check_n_basis_min(self) -> None: - pass - @support_pynapple(conv_type="numpy") @check_transform_input @check_one_dimensional @@ -975,25 +1019,6 @@ def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix: ) return X - def set_kernel(self, *xi: ArrayLike) -> Basis: - """Call fit on the added basis. - - If any of the added basis is in "conv" mode, it will prepare its kernels for the convolution. - - Parameters - ---------- - *xi: - The sample inputs. Unused, necessary to conform to ``scikit-learn`` API. - - Returns - ------- - : - The AdditiveBasis ready to be evaluated. - """ - self._basis1.set_kernel() - self._basis2.set_kernel() - return self - def split_by_feature( self, x: NDArray, @@ -1163,7 +1188,7 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: return super().evaluate_on_grid(*n_samples) -class MultiplicativeBasis(Basis): +class MultiplicativeBasis(CompositeBasisMixin, Basis): """ Class representing the multiplication (external product) of two Basis objects. @@ -1199,6 +1224,7 @@ class MultiplicativeBasis(Basis): def __init__(self, basis1: Basis, basis2: Basis) -> None: self.n_basis_funcs = basis1.n_basis_funcs * basis2.n_basis_funcs + CompositeBasisMixin.__init__(self) super().__init__(self.n_basis_funcs, mode="eval") self._n_input_dimensionality = ( basis1._n_input_dimensionality + basis2._n_input_dimensionality @@ -1209,9 +1235,8 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None: self._basis1 = basis1 self._basis2 = basis2 BasisTransformerMixin.__init__(self) + CompositeBasisMixin.__init__(self) - def _check_n_basis_min(self) -> None: - pass def set_kernel(self, *xi: NDArray) -> Basis: """Call fit on the multiplied basis. @@ -1298,12 +1323,59 @@ def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix: ) return X - def _set_num_output_features(self, *xi: NDArray) -> Basis: + def set_input_shape(self, *xi: int | tuple[int,...] | NDArray) -> Basis: + """ + Set the expected input shape for the basis object. + + This method sets the input shape for each component basis in the ``MultiplicativeBasis``. + One ``xi`` must be provided for each basis component, specified as an integer, + a tuple of integers, or an array. The method calculates and stores the total number of output features + based on the number of basis functions in each component and the provided input shapes. + + Parameters + ---------- + *xi : + The input shape specifications. For every k,``xi[k]`` can be: + - An integer: Represents the dimensionality of the input. A value of ``1`` is treated as scalar input. + - A tuple: Represents the exact input shape excluding the first axis (sample axis). + All elements must be integers. + - An array: The shape is extracted, excluding the first axis (assumed to be the sample axis). + + Raises + ------ + ValueError + If a tuple is provided and it contains non-integer elements. + + Returns + ------- + self : + Returns the instance itself to allow method chaining. + Examples + -------- + >>> # Generate sample data + >>> import numpy as np + >>> import nemos as nmo + + >>> # define an additive basis + >>> basis_1 = nmo.basis.BSplineEval(5) + >>> basis_2 = nmo.basis.RaisedCosineLinearEval(6) + >>> basis_3 = nmo.basis.MSplineEval(7) + >>> multiplicative_basis = basis_1 * basis_2 * basis_3 + + Specify the input shape using all 3 allowed ways: integer, tuple, array + >>> _ = multiplicative_basis.set_input_shape(1, (2, 3), np.ones((10, 4, 5))) + + Expected output features are: + (5 * 6 * 7 bases) * (1 * 6 * 20 inputs) = 25200 + >>> multiplicative_basis.n_output_features + 25200 + + """ self._n_basis_input = ( - *self._basis1._set_num_output_features( + *self._basis1.set_input_shape( *xi[: self._basis1._n_input_dimensionality] )._n_basis_input, - *self._basis2._set_num_output_features( + *self._basis2.set_input_shape( *xi[self._basis1._n_input_dimensionality :] )._n_basis_input, ) diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index d30c850e..40670ab9 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -4,7 +4,7 @@ import copy import inspect -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, TYPE_CHECKING import numpy as np from numpy.typing import NDArray @@ -12,6 +12,8 @@ from ..convolve import create_convolutional_predictor from ._transformer_basis import TransformerBasis +if TYPE_CHECKING: + from ._basis import Basis class EvalBasisMixin: """Mixin class for evaluational basis.""" @@ -258,3 +260,38 @@ def to_transformer(self) -> TransformerBasis: >>> gridsearch = gridsearch.fit(X, y) """ return TransformerBasis(copy.deepcopy(self)) + + +class CompositeBasisMixin: + """Mixin class for composite basis. + + Add overwrites concrete methods or defines abstract methods for composite basis + (AdditiveBasis and MultiplicativeBasis). + """ + + def _check_n_basis_min(self) -> None: + pass + + def set_kernel(self, *xi: NDArray) -> Basis: + """Call set_kernel on the basis elements. + + If any of the basis elements is in "conv" mode, it will prepare its kernels for the convolution. + + Parameters + ---------- + *xi: + The sample inputs. Unused, necessary to conform to ``scikit-learn`` API. + + Returns + ------- + : + The basis ready to be evaluated. + """ + self._basis1.set_kernel() + self._basis2.set_kernel() + return self + + def _check_input_shape_consistency(self, *xi: NDArray): + """Check the input shape consistency for all basis elements.""" + self._basis1._check_input_shape_consistency(*xi[: self._basis1._n_input_dimensionality]) + self._basis2._check_input_shape_consistency(*xi[self._basis1._n_input_dimensionality:]) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 730aaee9..b2ad0099 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -157,6 +157,35 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples) + @add_docstring("set_input_shape", BSplineBasis) + def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.BSplineEval(5) + + Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + + Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + + Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return super().set_input_shape(xi) + + class BSplineConv(ConvBasisMixin, BSplineBasis): """ @@ -277,6 +306,34 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples) + @add_docstring("set_input_shape", BSplineBasis) + def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.BSplineConv(5, 10) + + Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + + Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + + Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return super().set_input_shape(xi) + class CyclicBSplineEval(EvalBasisMixin, CyclicBSplineBasis): """ @@ -390,6 +447,34 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples) + @add_docstring("set_input_shape", CyclicBSplineBasis) + def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.CyclicBSplineEval(5) + + Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + + Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + + Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return super().set_input_shape(xi) + class CyclicBSplineConv(ConvBasisMixin, CyclicBSplineBasis): """ @@ -502,6 +587,34 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples) + @add_docstring("set_input_shape", CyclicBSplineBasis) + def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.CyclicBSplineConv(5, 10) + + Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + + Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + + Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return super().set_input_shape(xi) + class MSplineEval(EvalBasisMixin, MSplineBasis): r""" @@ -639,6 +752,34 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples) + @add_docstring("set_input_shape", MSplineBasis) + def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.MSplineEval(5) + + Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + + Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + + Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return super().set_input_shape(xi) + class MSplineConv(ConvBasisMixin, MSplineBasis): r""" @@ -775,6 +916,34 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples) + @add_docstring("set_input_shape", MSplineBasis) + def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.MSplineConv(5, 10) + + Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + + Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + + Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return super().set_input_shape(xi) + class RaisedCosineLinearEval( EvalBasisMixin, RaisedCosineBasisLinear, BasisTransformerMixin @@ -891,6 +1060,34 @@ def split_by_feature( """ return super().split_by_feature(x, axis=axis) + @add_docstring("set_input_shape", RaisedCosineBasisLinear) + def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.RaisedCosineLinearEval(5) + + Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + + Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + + Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return super().set_input_shape(xi) + class RaisedCosineLinearConv( ConvBasisMixin, RaisedCosineBasisLinear, BasisTransformerMixin @@ -1006,6 +1203,34 @@ def split_by_feature( """ return super().split_by_feature(x, axis=axis) + @add_docstring("set_input_shape", RaisedCosineBasisLinear) + def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.RaisedCosineLinearConv(5, 10) + + Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + + Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + + Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return super().set_input_shape(xi) + class RaisedCosineLogEval(EvalBasisMixin, RaisedCosineBasisLog): """Represent log-spaced raised cosine basis functions. @@ -1130,6 +1355,34 @@ def split_by_feature( """ return super().split_by_feature(x, axis=axis) + @add_docstring("set_input_shape", RaisedCosineBasisLog) + def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.RaisedCosineLogEval(5) + + Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + + Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + + Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return super().set_input_shape(xi) + class RaisedCosineLogConv(ConvBasisMixin, RaisedCosineBasisLog): """Represent log-spaced raised cosine basis functions. @@ -1255,6 +1508,33 @@ def split_by_feature( """ return super().split_by_feature(x, axis=axis) + @add_docstring("set_input_shape", RaisedCosineBasisLog) + def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.RaisedCosineLogConv(5, 10) + + Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + + Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + + Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return super().set_input_shape(xi) class OrthExponentialEval(EvalBasisMixin, OrthExponentialBasis): """Set of 1D basis decaying exponential functions numerically orthogonalized. @@ -1367,6 +1647,33 @@ def split_by_feature( """ return super().split_by_feature(x, axis=axis) + @add_docstring("set_input_shape", OrthExponentialBasis) + def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.OrthExponentialEval(5, decay_rates=np.arange(1, 6)) + + Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + + Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + + Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return super().set_input_shape(xi) class OrthExponentialConv(ConvBasisMixin, OrthExponentialBasis): """Set of 1D basis decaying exponential functions numerically orthogonalized. @@ -1476,3 +1783,31 @@ def split_by_feature( """ return super().split_by_feature(x, axis=axis) + + @add_docstring("set_input_shape", OrthExponentialBasis) + def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.OrthExponentialConv(5, window_size=10, decay_rates=np.arange(1, 6)) + + Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + + Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + + Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return super().set_input_shape(xi) diff --git a/tests/test_basis.py b/tests/test_basis.py index 002cfea4..193b69a1 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -130,6 +130,10 @@ def test_all_basis_are_tested() -> None: "split_by_feature", "Decompose an array along a specified axis into sub-arrays", ), + ( + "set_input_shape", + "Set the expected input shape for the basis object", + ), ], ) def test_example_docstrings_add( @@ -1254,6 +1258,44 @@ def test_transformer_get_params(self, cls): assert params_transf == params_basis assert np.all(rates_1 == rates_2) + @pytest.mark.parametrize( + "x, inp_shape, expectation", + [ + (np.ones((10,)), 1, does_not_raise()), + (np.ones((10, 1)), 1, pytest.raises(ValueError, match="Input shape mismatch detected")), + (np.ones((10, 2)), 2, does_not_raise()), + (np.ones((10, 1)), 2, pytest.raises(ValueError, match="Input shape mismatch detected")), + (np.ones((10, 2, 1)), 2, pytest.raises(ValueError, match="Input shape mismatch detected")), + (np.ones((10, 1, 2)), 2, pytest.raises(ValueError, match="Input shape mismatch detected")), + (np.ones((10, 1)), (1,), does_not_raise()), + (np.ones((10,)), tuple(), does_not_raise()), + (np.ones((10,)), np.zeros((12,)), does_not_raise()), + (np.ones((10,)), (1,), pytest.raises(ValueError, match="Input shape mismatch detected")), + (np.ones((10,1)), (), pytest.raises(ValueError, match="Input shape mismatch detected")), + (np.ones((10,1)), np.zeros((12, )), pytest.raises(ValueError, match="Input shape mismatch detected")), + (np.ones((10)), np.zeros((12, 1)), pytest.raises(ValueError, match="Input shape mismatch detected")), + + ] + ) + def test_input_shape_validity(self, x, inp_shape, expectation, cls): + bas = cls["eval"](n_basis_funcs=5, **extra_decay_rates(cls["eval"], 5)) + bas.set_input_shape(inp_shape) + with expectation: + bas.compute_features(x) + + @pytest.mark.parametrize( + "inp_shape, expectation", + [ + ((1, 1), does_not_raise()), + ((1, 1.), pytest.raises(ValueError, match="The tuple provided contains non integer")), + (np.ones((1, )), does_not_raise()), + (np.ones((1, 1)), does_not_raise()), + ] + ) + def test_set_input_value_types(self, inp_shape, expectation, cls): + bas = cls["eval"](n_basis_funcs=5, **extra_decay_rates(cls["eval"], 5)) + with expectation: + bas.set_input_shape(inp_shape) class TestRaisedCosineLogBasis(BasisFuncsTesting): cls = {"eval": basis.RaisedCosineLogEval, "conv": basis.RaisedCosineLogConv} @@ -2497,6 +2539,99 @@ def test_expected_input_number(self, n_input, expectation): with expectation: bas.compute_features(np.random.randn(30, 2), np.random.randn(30, n_input)) + @pytest.mark.parametrize("basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + @pytest.mark.parametrize("basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + @pytest.mark.parametrize("shape_a", [1, (), np.ones(3)]) + @pytest.mark.parametrize("shape_b", [1, (), np.ones(3)]) + @pytest.mark.parametrize("add_shape_a", [(), (1,)]) + @pytest.mark.parametrize("add_shape_b", [(), (1,)]) + def test_set_input_shape_type_1d_arrays(self, basis_a, basis_b, shape_a, shape_b, class_specific_params, add_shape_a, add_shape_b): + x = (np.ones((10, *add_shape_a)), np.ones((10, *add_shape_b))) + basis_a = self.instantiate_basis( + 5, basis_a, class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, class_specific_params, window_size=10 + ) + add = basis_a + basis_b + + add.set_input_shape(shape_a, shape_b) + if add_shape_a == () and add_shape_b == (): + expectation = does_not_raise() + else: + expectation = pytest.raises(ValueError, match="Input shape mismatch detected") + with expectation: + add.compute_features(*x) + + @pytest.mark.parametrize("basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + @pytest.mark.parametrize("basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + @pytest.mark.parametrize("shape_a", [2, (2,), np.ones((3, 2))]) + @pytest.mark.parametrize("shape_b", [3, (3,), np.ones((3, 3))]) + @pytest.mark.parametrize("add_shape_a", [(), (1,)]) + @pytest.mark.parametrize("add_shape_b", [(), (1,)]) + def test_set_input_shape_type_2d_arrays(self, basis_a, basis_b, shape_a, shape_b, class_specific_params, + add_shape_a, add_shape_b): + x = (np.ones((10, 2, *add_shape_a)), np.ones((10, 3, *add_shape_b))) + basis_a = self.instantiate_basis( + 5, basis_a, class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, class_specific_params, window_size=10 + ) + add = basis_a + basis_b + + add.set_input_shape(shape_a, shape_b) + if add_shape_a == () and add_shape_b == (): + expectation = does_not_raise() + else: + expectation = pytest.raises(ValueError, match="Input shape mismatch detected") + with expectation: + add.compute_features(*x) + + @pytest.mark.parametrize("basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + @pytest.mark.parametrize("basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + @pytest.mark.parametrize("shape_a", [(2, 2), np.ones((3, 2, 2))]) + @pytest.mark.parametrize("shape_b", [(3, 1), np.ones((3, 3, 1))]) + @pytest.mark.parametrize("add_shape_a", [(), (1,)]) + @pytest.mark.parametrize("add_shape_b", [(), (1,)]) + def test_set_input_shape_type_nd_arrays(self, basis_a, basis_b, shape_a, shape_b, class_specific_params,add_shape_a,add_shape_b): + x = (np.ones((10, 2, 2, *add_shape_a)), np.ones((10, 3, 1, *add_shape_b))) + basis_a = self.instantiate_basis( + 5, basis_a, class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, class_specific_params, window_size=10 + ) + add = basis_a + basis_b + + add.set_input_shape(shape_a, shape_b) + if add_shape_a == () and add_shape_b == (): + expectation = does_not_raise() + else: + expectation = pytest.raises(ValueError, match="Input shape mismatch detected") + with expectation: + add.compute_features(*x) + + @pytest.mark.parametrize( + "inp_shape, expectation", + [ + (((1, 1), (1, 1)), does_not_raise()), + (((1, 1.), (1, 1)), pytest.raises(ValueError, match="The tuple provided contains non integer")), + (((1, 1), (1, 1.)), pytest.raises(ValueError, match="The tuple provided contains non integer")), + ] + ) + @pytest.mark.parametrize("basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + @pytest.mark.parametrize("basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + def test_set_input_value_types(self, inp_shape, expectation, basis_a, basis_b, class_specific_params): + basis_a = self.instantiate_basis( + 5, basis_a, class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, class_specific_params, window_size=10 + ) + add = basis_a + basis_b + with expectation: + add.set_input_shape(*inp_shape) class TestMultiplicativeBasis(CombinedBasis): cls = {"eval": MultiplicativeBasis, "conv": MultiplicativeBasis} @@ -3149,6 +3284,100 @@ def test_n_basis_input(self, n_basis_input1, n_basis_input2): ) assert bas_prod.n_basis_input == (n_basis_input1, n_basis_input2) + @pytest.mark.parametrize("basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + @pytest.mark.parametrize("basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + @pytest.mark.parametrize("shape_a", [1, (), np.ones(3)]) + @pytest.mark.parametrize("shape_b", [1, (), np.ones(3)]) + @pytest.mark.parametrize("add_shape_a", [(), (1,)]) + @pytest.mark.parametrize("add_shape_b", [(), (1,)]) + def test_set_input_shape_type_1d_arrays(self, basis_a, basis_b, shape_a, shape_b, class_specific_params, add_shape_a, add_shape_b): + x = (np.ones((10, *add_shape_a)), np.ones((10, *add_shape_b))) + basis_a = self.instantiate_basis( + 5, basis_a, class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, class_specific_params, window_size=10 + ) + mul = basis_a * basis_b + + mul.set_input_shape(shape_a, shape_b) + if add_shape_a == () and add_shape_b == (): + expectation = does_not_raise() + else: + expectation = pytest.raises(ValueError, match="Input shape mismatch detected") + with expectation: + mul.compute_features(*x) + + @pytest.mark.parametrize("basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + @pytest.mark.parametrize("basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + @pytest.mark.parametrize("shape_a", [2, (2,), np.ones((3, 2))]) + @pytest.mark.parametrize("shape_b", [3, (3,), np.ones((3, 3))]) + @pytest.mark.parametrize("add_shape_a", [(), (1,)]) + @pytest.mark.parametrize("add_shape_b", [(), (1,)]) + def test_set_input_shape_type_2d_arrays(self, basis_a, basis_b, shape_a, shape_b, class_specific_params, + add_shape_a, add_shape_b): + x = (np.ones((10, 2, *add_shape_a)), np.ones((10, 3, *add_shape_b))) + basis_a = self.instantiate_basis( + 5, basis_a, class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, class_specific_params, window_size=10 + ) + mul = basis_a * basis_b + + mul.set_input_shape(shape_a, shape_b) + if add_shape_a == () and add_shape_b == (): + expectation = does_not_raise() + else: + expectation = pytest.raises(ValueError, match="Input shape mismatch detected") + with expectation: + mul.compute_features(*x) + + @pytest.mark.parametrize("basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + @pytest.mark.parametrize("basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + @pytest.mark.parametrize("shape_a", [(2, 2), np.ones((3, 2, 2))]) + @pytest.mark.parametrize("shape_b", [(3, 1), np.ones((3, 3, 1))]) + @pytest.mark.parametrize("add_shape_a", [(), (1,)]) + @pytest.mark.parametrize("add_shape_b", [(), (1,)]) + def test_set_input_shape_type_nd_arrays(self, basis_a, basis_b, shape_a, shape_b, class_specific_params,add_shape_a,add_shape_b): + x = (np.ones((10, 2, 2, *add_shape_a)), np.ones((10, 3, 1, *add_shape_b))) + basis_a = self.instantiate_basis( + 5, basis_a, class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, class_specific_params, window_size=10 + ) + mul = basis_a * basis_b + + mul.set_input_shape(shape_a, shape_b) + if add_shape_a == () and add_shape_b == (): + expectation = does_not_raise() + else: + expectation = pytest.raises(ValueError, match="Input shape mismatch detected") + with expectation: + mul.compute_features(*x) + + @pytest.mark.parametrize( + "inp_shape, expectation", + [ + (((1, 1), (1, 1)), does_not_raise()), + (((1, 1.), (1, 1)), pytest.raises(ValueError, match="The tuple provided contains non integer")), + (((1, 1), (1, 1.)), pytest.raises(ValueError, match="The tuple provided contains non integer")), + ] + ) + @pytest.mark.parametrize("basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + @pytest.mark.parametrize("basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + def test_set_input_value_types(self, inp_shape, expectation, basis_a, basis_b, class_specific_params): + basis_a = self.instantiate_basis( + 5, basis_a, class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, class_specific_params, window_size=10 + ) + mul = basis_a * basis_b + with expectation: + mul.set_input_shape(*inp_shape) + @pytest.mark.parametrize( "exponent", [-1, 0, 0.5, basis.RaisedCosineLogEval(4), 1, 2, 3] @@ -3843,7 +4072,7 @@ def test__get_splitter( bas23 = func2(bas3_instance) bas123 = func1(bas23) inps = [np.zeros((1, n)) if n > 1 else np.zeros((1,)) for n in n_input_basis] - bas123._set_num_output_features(*inps) + bas123.set_input_shape(*inps) splitter_dict, _ = bas123._get_feature_slicing(split_by_input=False) exp_slices = compute_slice(bas1_instance, bas2_instance, bas3_instance) assert exp_slices == splitter_dict @@ -3997,7 +4226,7 @@ def test__get_splitter_split_by_input( np.zeros((1, n)) if n > 1 else np.zeros((1,)) for n in (n_input_basis_1, n_input_basis_2) ] - bas12._set_num_output_features(*inps) + bas12.set_input_shape(*inps) splitter_dict, _ = bas12._get_feature_slicing() exp_slices = compute_slice(bas1_instance, bas2_instance) assert exp_slices == splitter_dict @@ -4028,7 +4257,7 @@ def test_duplicate_keys(bas1, bas2, bas3, class_specific_params): bas_obj = bas1_instance + bas2_instance + bas3_instance inps = [np.zeros((1,)) for n in range(3)] - bas_obj._set_num_output_features(*inps) + bas_obj.set_input_shape(*inps) slice_dict = bas_obj._get_feature_slicing()[0] assert tuple(slice_dict.keys()) == ("label", "label-1", "label-2") @@ -4073,7 +4302,7 @@ def test_split_feature_axis( ) bas = bas1_instance + bas2_instance - bas._set_num_output_features(np.zeros((1, 2)), np.zeros((1, 3))) + bas.set_input_shape(np.zeros((1, 2)), np.zeros((1, 3))) with expectation: out = bas.split_by_feature(x, axis=axis) for i, itm in enumerate(out.items()): From 24697c2dbdc2facac9cbd8b0bb6a3925667684b8 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 5 Dec 2024 15:35:08 -0500 Subject: [PATCH 02/41] linted --- src/nemos/basis/_basis.py | 11 +- src/nemos/basis/_basis_mixin.py | 11 +- src/nemos/basis/basis.py | 27 ++-- tests/test_basis.py | 268 +++++++++++++++++++++++++------- 4 files changed, 241 insertions(+), 76 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 04fa02b4..fb9ee3cd 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -773,7 +773,7 @@ def _check_input_shape_consistency(self, x: NDArray): "different shape, please create a new basis instance." ) - def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Set the expected input shape for the basis object. @@ -803,7 +803,9 @@ def set_input_shape(self, xi: int | tuple[int,...] | NDArray): """ if isinstance(xi, tuple): if not all(isinstance(i, int) for i in xi): - raise ValueError(f"The tuple provided contains non integer values. Tuple: {xi}.") + raise ValueError( + f"The tuple provided contains non integer values. Tuple: {xi}." + ) shape = xi elif isinstance(xi, int): shape = () if xi == 1 else (xi,) @@ -866,7 +868,7 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None: self._basis2 = basis2 CompositeBasisMixin.__init__(self) - def set_input_shape(self, *xi: int | tuple[int,...] | NDArray) -> Basis: + def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: """ Set the expected input shape for the basis object. @@ -1237,7 +1239,6 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None: BasisTransformerMixin.__init__(self) CompositeBasisMixin.__init__(self) - def set_kernel(self, *xi: NDArray) -> Basis: """Call fit on the multiplied basis. @@ -1323,7 +1324,7 @@ def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix: ) return X - def set_input_shape(self, *xi: int | tuple[int,...] | NDArray) -> Basis: + def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: """ Set the expected input shape for the basis object. diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 40670ab9..9bd1b09c 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -4,7 +4,7 @@ import copy import inspect -from typing import Optional, Tuple, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Tuple, Union import numpy as np from numpy.typing import NDArray @@ -15,6 +15,7 @@ if TYPE_CHECKING: from ._basis import Basis + class EvalBasisMixin: """Mixin class for evaluational basis.""" @@ -293,5 +294,9 @@ def set_kernel(self, *xi: NDArray) -> Basis: def _check_input_shape_consistency(self, *xi: NDArray): """Check the input shape consistency for all basis elements.""" - self._basis1._check_input_shape_consistency(*xi[: self._basis1._n_input_dimensionality]) - self._basis2._check_input_shape_consistency(*xi[self._basis1._n_input_dimensionality:]) + self._basis1._check_input_shape_consistency( + *xi[: self._basis1._n_input_dimensionality] + ) + self._basis2._check_input_shape_consistency( + *xi[self._basis1._n_input_dimensionality :] + ) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index b2ad0099..88a57901 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -158,7 +158,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return super().evaluate_on_grid(n_samples) @add_docstring("set_input_shape", BSplineBasis) - def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples -------- @@ -186,7 +186,6 @@ def set_input_shape(self, xi: int | tuple[int,...] | NDArray): return super().set_input_shape(xi) - class BSplineConv(ConvBasisMixin, BSplineBasis): """ B-spline 1-dimensional basis functions. @@ -307,7 +306,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return super().evaluate_on_grid(n_samples) @add_docstring("set_input_shape", BSplineBasis) - def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples -------- @@ -448,7 +447,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return super().evaluate_on_grid(n_samples) @add_docstring("set_input_shape", CyclicBSplineBasis) - def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples -------- @@ -588,7 +587,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return super().evaluate_on_grid(n_samples) @add_docstring("set_input_shape", CyclicBSplineBasis) - def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples -------- @@ -753,7 +752,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return super().evaluate_on_grid(n_samples) @add_docstring("set_input_shape", MSplineBasis) - def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples -------- @@ -917,7 +916,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return super().evaluate_on_grid(n_samples) @add_docstring("set_input_shape", MSplineBasis) - def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples -------- @@ -1061,7 +1060,7 @@ def split_by_feature( return super().split_by_feature(x, axis=axis) @add_docstring("set_input_shape", RaisedCosineBasisLinear) - def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples -------- @@ -1204,7 +1203,7 @@ def split_by_feature( return super().split_by_feature(x, axis=axis) @add_docstring("set_input_shape", RaisedCosineBasisLinear) - def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples -------- @@ -1356,7 +1355,7 @@ def split_by_feature( return super().split_by_feature(x, axis=axis) @add_docstring("set_input_shape", RaisedCosineBasisLog) - def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples -------- @@ -1509,7 +1508,7 @@ def split_by_feature( return super().split_by_feature(x, axis=axis) @add_docstring("set_input_shape", RaisedCosineBasisLog) - def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples -------- @@ -1536,6 +1535,7 @@ def set_input_shape(self, xi: int | tuple[int,...] | NDArray): """ return super().set_input_shape(xi) + class OrthExponentialEval(EvalBasisMixin, OrthExponentialBasis): """Set of 1D basis decaying exponential functions numerically orthogonalized. @@ -1648,7 +1648,7 @@ def split_by_feature( return super().split_by_feature(x, axis=axis) @add_docstring("set_input_shape", OrthExponentialBasis) - def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples -------- @@ -1675,6 +1675,7 @@ def set_input_shape(self, xi: int | tuple[int,...] | NDArray): """ return super().set_input_shape(xi) + class OrthExponentialConv(ConvBasisMixin, OrthExponentialBasis): """Set of 1D basis decaying exponential functions numerically orthogonalized. @@ -1785,7 +1786,7 @@ def split_by_feature( return super().split_by_feature(x, axis=axis) @add_docstring("set_input_shape", OrthExponentialBasis) - def set_input_shape(self, xi: int | tuple[int,...] | NDArray): + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples -------- diff --git a/tests/test_basis.py b/tests/test_basis.py index 193b69a1..6b7a235a 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -1262,20 +1262,51 @@ def test_transformer_get_params(self, cls): "x, inp_shape, expectation", [ (np.ones((10,)), 1, does_not_raise()), - (np.ones((10, 1)), 1, pytest.raises(ValueError, match="Input shape mismatch detected")), + ( + np.ones((10, 1)), + 1, + pytest.raises(ValueError, match="Input shape mismatch detected"), + ), (np.ones((10, 2)), 2, does_not_raise()), - (np.ones((10, 1)), 2, pytest.raises(ValueError, match="Input shape mismatch detected")), - (np.ones((10, 2, 1)), 2, pytest.raises(ValueError, match="Input shape mismatch detected")), - (np.ones((10, 1, 2)), 2, pytest.raises(ValueError, match="Input shape mismatch detected")), - (np.ones((10, 1)), (1,), does_not_raise()), + ( + np.ones((10, 1)), + 2, + pytest.raises(ValueError, match="Input shape mismatch detected"), + ), + ( + np.ones((10, 2, 1)), + 2, + pytest.raises(ValueError, match="Input shape mismatch detected"), + ), + ( + np.ones((10, 1, 2)), + 2, + pytest.raises(ValueError, match="Input shape mismatch detected"), + ), + (np.ones((10, 1)), (1,), does_not_raise()), (np.ones((10,)), tuple(), does_not_raise()), (np.ones((10,)), np.zeros((12,)), does_not_raise()), - (np.ones((10,)), (1,), pytest.raises(ValueError, match="Input shape mismatch detected")), - (np.ones((10,1)), (), pytest.raises(ValueError, match="Input shape mismatch detected")), - (np.ones((10,1)), np.zeros((12, )), pytest.raises(ValueError, match="Input shape mismatch detected")), - (np.ones((10)), np.zeros((12, 1)), pytest.raises(ValueError, match="Input shape mismatch detected")), - - ] + ( + np.ones((10,)), + (1,), + pytest.raises(ValueError, match="Input shape mismatch detected"), + ), + ( + np.ones((10, 1)), + (), + pytest.raises(ValueError, match="Input shape mismatch detected"), + ), + ( + np.ones((10, 1)), + np.zeros((12,)), + pytest.raises(ValueError, match="Input shape mismatch detected"), + ), + ( + np.ones((10)), + np.zeros((12, 1)), + pytest.raises(ValueError, match="Input shape mismatch detected"), + ), + ], ) def test_input_shape_validity(self, x, inp_shape, expectation, cls): bas = cls["eval"](n_basis_funcs=5, **extra_decay_rates(cls["eval"], 5)) @@ -1287,16 +1318,22 @@ def test_input_shape_validity(self, x, inp_shape, expectation, cls): "inp_shape, expectation", [ ((1, 1), does_not_raise()), - ((1, 1.), pytest.raises(ValueError, match="The tuple provided contains non integer")), - (np.ones((1, )), does_not_raise()), + ( + (1, 1.0), + pytest.raises( + ValueError, match="The tuple provided contains non integer" + ), + ), + (np.ones((1,)), does_not_raise()), (np.ones((1, 1)), does_not_raise()), - ] + ], ) def test_set_input_value_types(self, inp_shape, expectation, cls): bas = cls["eval"](n_basis_funcs=5, **extra_decay_rates(cls["eval"], 5)) - with expectation: + with expectation: bas.set_input_shape(inp_shape) + class TestRaisedCosineLogBasis(BasisFuncsTesting): cls = {"eval": basis.RaisedCosineLogEval, "conv": basis.RaisedCosineLogConv} @@ -2539,13 +2576,26 @@ def test_expected_input_number(self, n_input, expectation): with expectation: bas.compute_features(np.random.randn(30, 2), np.random.randn(30, n_input)) - @pytest.mark.parametrize("basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) - @pytest.mark.parametrize("basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) @pytest.mark.parametrize("shape_a", [1, (), np.ones(3)]) @pytest.mark.parametrize("shape_b", [1, (), np.ones(3)]) @pytest.mark.parametrize("add_shape_a", [(), (1,)]) @pytest.mark.parametrize("add_shape_b", [(), (1,)]) - def test_set_input_shape_type_1d_arrays(self, basis_a, basis_b, shape_a, shape_b, class_specific_params, add_shape_a, add_shape_b): + def test_set_input_shape_type_1d_arrays( + self, + basis_a, + basis_b, + shape_a, + shape_b, + class_specific_params, + add_shape_a, + add_shape_b, + ): x = (np.ones((10, *add_shape_a)), np.ones((10, *add_shape_b))) basis_a = self.instantiate_basis( 5, basis_a, class_specific_params, window_size=10 @@ -2559,18 +2609,32 @@ def test_set_input_shape_type_1d_arrays(self, basis_a, basis_b, shape_a, shape_b if add_shape_a == () and add_shape_b == (): expectation = does_not_raise() else: - expectation = pytest.raises(ValueError, match="Input shape mismatch detected") + expectation = pytest.raises( + ValueError, match="Input shape mismatch detected" + ) with expectation: add.compute_features(*x) - @pytest.mark.parametrize("basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) - @pytest.mark.parametrize("basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) @pytest.mark.parametrize("shape_a", [2, (2,), np.ones((3, 2))]) @pytest.mark.parametrize("shape_b", [3, (3,), np.ones((3, 3))]) @pytest.mark.parametrize("add_shape_a", [(), (1,)]) @pytest.mark.parametrize("add_shape_b", [(), (1,)]) - def test_set_input_shape_type_2d_arrays(self, basis_a, basis_b, shape_a, shape_b, class_specific_params, - add_shape_a, add_shape_b): + def test_set_input_shape_type_2d_arrays( + self, + basis_a, + basis_b, + shape_a, + shape_b, + class_specific_params, + add_shape_a, + add_shape_b, + ): x = (np.ones((10, 2, *add_shape_a)), np.ones((10, 3, *add_shape_b))) basis_a = self.instantiate_basis( 5, basis_a, class_specific_params, window_size=10 @@ -2584,17 +2648,32 @@ def test_set_input_shape_type_2d_arrays(self, basis_a, basis_b, shape_a, shape_b if add_shape_a == () and add_shape_b == (): expectation = does_not_raise() else: - expectation = pytest.raises(ValueError, match="Input shape mismatch detected") + expectation = pytest.raises( + ValueError, match="Input shape mismatch detected" + ) with expectation: add.compute_features(*x) - @pytest.mark.parametrize("basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) - @pytest.mark.parametrize("basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) @pytest.mark.parametrize("shape_a", [(2, 2), np.ones((3, 2, 2))]) @pytest.mark.parametrize("shape_b", [(3, 1), np.ones((3, 3, 1))]) @pytest.mark.parametrize("add_shape_a", [(), (1,)]) @pytest.mark.parametrize("add_shape_b", [(), (1,)]) - def test_set_input_shape_type_nd_arrays(self, basis_a, basis_b, shape_a, shape_b, class_specific_params,add_shape_a,add_shape_b): + def test_set_input_shape_type_nd_arrays( + self, + basis_a, + basis_b, + shape_a, + shape_b, + class_specific_params, + add_shape_a, + add_shape_b, + ): x = (np.ones((10, 2, 2, *add_shape_a)), np.ones((10, 3, 1, *add_shape_b))) basis_a = self.instantiate_basis( 5, basis_a, class_specific_params, window_size=10 @@ -2608,7 +2687,9 @@ def test_set_input_shape_type_nd_arrays(self, basis_a, basis_b, shape_a, shape_b if add_shape_a == () and add_shape_b == (): expectation = does_not_raise() else: - expectation = pytest.raises(ValueError, match="Input shape mismatch detected") + expectation = pytest.raises( + ValueError, match="Input shape mismatch detected" + ) with expectation: add.compute_features(*x) @@ -2616,13 +2697,29 @@ def test_set_input_shape_type_nd_arrays(self, basis_a, basis_b, shape_a, shape_b "inp_shape, expectation", [ (((1, 1), (1, 1)), does_not_raise()), - (((1, 1.), (1, 1)), pytest.raises(ValueError, match="The tuple provided contains non integer")), - (((1, 1), (1, 1.)), pytest.raises(ValueError, match="The tuple provided contains non integer")), - ] + ( + ((1, 1.0), (1, 1)), + pytest.raises( + ValueError, match="The tuple provided contains non integer" + ), + ), + ( + ((1, 1), (1, 1.0)), + pytest.raises( + ValueError, match="The tuple provided contains non integer" + ), + ), + ], + ) + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") ) - @pytest.mark.parametrize("basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) - @pytest.mark.parametrize("basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) - def test_set_input_value_types(self, inp_shape, expectation, basis_a, basis_b, class_specific_params): + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + def test_set_input_value_types( + self, inp_shape, expectation, basis_a, basis_b, class_specific_params + ): basis_a = self.instantiate_basis( 5, basis_a, class_specific_params, window_size=10 ) @@ -2630,9 +2727,10 @@ def test_set_input_value_types(self, inp_shape, expectation, basis_a, basis_b, c 5, basis_b, class_specific_params, window_size=10 ) add = basis_a + basis_b - with expectation: + with expectation: add.set_input_shape(*inp_shape) + class TestMultiplicativeBasis(CombinedBasis): cls = {"eval": MultiplicativeBasis, "conv": MultiplicativeBasis} @@ -3284,13 +3382,26 @@ def test_n_basis_input(self, n_basis_input1, n_basis_input2): ) assert bas_prod.n_basis_input == (n_basis_input1, n_basis_input2) - @pytest.mark.parametrize("basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) - @pytest.mark.parametrize("basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) @pytest.mark.parametrize("shape_a", [1, (), np.ones(3)]) @pytest.mark.parametrize("shape_b", [1, (), np.ones(3)]) @pytest.mark.parametrize("add_shape_a", [(), (1,)]) @pytest.mark.parametrize("add_shape_b", [(), (1,)]) - def test_set_input_shape_type_1d_arrays(self, basis_a, basis_b, shape_a, shape_b, class_specific_params, add_shape_a, add_shape_b): + def test_set_input_shape_type_1d_arrays( + self, + basis_a, + basis_b, + shape_a, + shape_b, + class_specific_params, + add_shape_a, + add_shape_b, + ): x = (np.ones((10, *add_shape_a)), np.ones((10, *add_shape_b))) basis_a = self.instantiate_basis( 5, basis_a, class_specific_params, window_size=10 @@ -3304,18 +3415,32 @@ def test_set_input_shape_type_1d_arrays(self, basis_a, basis_b, shape_a, shape_b if add_shape_a == () and add_shape_b == (): expectation = does_not_raise() else: - expectation = pytest.raises(ValueError, match="Input shape mismatch detected") + expectation = pytest.raises( + ValueError, match="Input shape mismatch detected" + ) with expectation: mul.compute_features(*x) - @pytest.mark.parametrize("basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) - @pytest.mark.parametrize("basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) @pytest.mark.parametrize("shape_a", [2, (2,), np.ones((3, 2))]) @pytest.mark.parametrize("shape_b", [3, (3,), np.ones((3, 3))]) @pytest.mark.parametrize("add_shape_a", [(), (1,)]) @pytest.mark.parametrize("add_shape_b", [(), (1,)]) - def test_set_input_shape_type_2d_arrays(self, basis_a, basis_b, shape_a, shape_b, class_specific_params, - add_shape_a, add_shape_b): + def test_set_input_shape_type_2d_arrays( + self, + basis_a, + basis_b, + shape_a, + shape_b, + class_specific_params, + add_shape_a, + add_shape_b, + ): x = (np.ones((10, 2, *add_shape_a)), np.ones((10, 3, *add_shape_b))) basis_a = self.instantiate_basis( 5, basis_a, class_specific_params, window_size=10 @@ -3329,17 +3454,32 @@ def test_set_input_shape_type_2d_arrays(self, basis_a, basis_b, shape_a, shape_b if add_shape_a == () and add_shape_b == (): expectation = does_not_raise() else: - expectation = pytest.raises(ValueError, match="Input shape mismatch detected") + expectation = pytest.raises( + ValueError, match="Input shape mismatch detected" + ) with expectation: mul.compute_features(*x) - @pytest.mark.parametrize("basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) - @pytest.mark.parametrize("basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) @pytest.mark.parametrize("shape_a", [(2, 2), np.ones((3, 2, 2))]) @pytest.mark.parametrize("shape_b", [(3, 1), np.ones((3, 3, 1))]) @pytest.mark.parametrize("add_shape_a", [(), (1,)]) @pytest.mark.parametrize("add_shape_b", [(), (1,)]) - def test_set_input_shape_type_nd_arrays(self, basis_a, basis_b, shape_a, shape_b, class_specific_params,add_shape_a,add_shape_b): + def test_set_input_shape_type_nd_arrays( + self, + basis_a, + basis_b, + shape_a, + shape_b, + class_specific_params, + add_shape_a, + add_shape_b, + ): x = (np.ones((10, 2, 2, *add_shape_a)), np.ones((10, 3, 1, *add_shape_b))) basis_a = self.instantiate_basis( 5, basis_a, class_specific_params, window_size=10 @@ -3353,7 +3493,9 @@ def test_set_input_shape_type_nd_arrays(self, basis_a, basis_b, shape_a, shape_b if add_shape_a == () and add_shape_b == (): expectation = does_not_raise() else: - expectation = pytest.raises(ValueError, match="Input shape mismatch detected") + expectation = pytest.raises( + ValueError, match="Input shape mismatch detected" + ) with expectation: mul.compute_features(*x) @@ -3361,13 +3503,29 @@ def test_set_input_shape_type_nd_arrays(self, basis_a, basis_b, shape_a, shape_b "inp_shape, expectation", [ (((1, 1), (1, 1)), does_not_raise()), - (((1, 1.), (1, 1)), pytest.raises(ValueError, match="The tuple provided contains non integer")), - (((1, 1), (1, 1.)), pytest.raises(ValueError, match="The tuple provided contains non integer")), - ] + ( + ((1, 1.0), (1, 1)), + pytest.raises( + ValueError, match="The tuple provided contains non integer" + ), + ), + ( + ((1, 1), (1, 1.0)), + pytest.raises( + ValueError, match="The tuple provided contains non integer" + ), + ), + ], + ) + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") ) - @pytest.mark.parametrize("basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) - @pytest.mark.parametrize("basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) - def test_set_input_value_types(self, inp_shape, expectation, basis_a, basis_b, class_specific_params): + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + def test_set_input_value_types( + self, inp_shape, expectation, basis_a, basis_b, class_specific_params + ): basis_a = self.instantiate_basis( 5, basis_a, class_specific_params, window_size=10 ) @@ -3375,7 +3533,7 @@ def test_set_input_value_types(self, inp_shape, expectation, basis_a, basis_b, c 5, basis_b, class_specific_params, window_size=10 ) mul = basis_a * basis_b - with expectation: + with expectation: mul.set_input_shape(*inp_shape) From ce5dff2bd9f2df1817b051e331e2b33c590ce630 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 12 Dec 2024 23:54:18 -0500 Subject: [PATCH 03/41] re-struct bases --- pyproject.toml | 5 + src/nemos/basis/_basis.py | 382 +++--- src/nemos/basis/_basis_mixin.py | 198 ++- src/nemos/basis/_decaying_exponential.py | 4 - src/nemos/basis/_raised_cosine_basis.py | 6 - src/nemos/basis/_spline_basis.py | 17 - src/nemos/basis/_transformer_basis.py | 51 +- src/nemos/basis/basis.py | 97 +- tests/conftest.py | 103 ++ tests/test_basis.py | 1426 +++++++++++----------- tests/test_pipeline.py | 17 +- 11 files changed, 1310 insertions(+), 996 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4c6134ef..d20fd307 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,6 +123,11 @@ profile = "black" # Configure pytest [tool.pytest.ini_options] testpaths = ["tests"] # Specify the directory where test files are located +filterwarnings = [ + # note the use of single quote below to denote "raw" strings in TOML + 'ignore:plotting functions contained within:UserWarning', + 'ignore:Tolerance of \d\.\d+e-\d\d reached:RuntimeWarning', +] [tool.coverage.run] omit = [ diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index d57b2ad2..adaa2c8a 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -53,7 +53,7 @@ def check_one_dimensional(func: Callable) -> Callable: """Check if the input is one-dimensional.""" @wraps(func) - def wrapper(self: Basis, *xi: ArrayLike, **kwargs): + def wrapper(self: Basis, *xi: NDArray, **kwargs): if any(x.ndim != 1 for x in xi): raise ValueError("Input sample must be one dimensional!") return func(self, *xi, **kwargs) @@ -135,28 +135,28 @@ class Basis(Base, abc.ABC, BasisTransformerMixin): def __init__( self, - n_basis_funcs: int, mode: Literal["eval", "conv"] = "eval", label: Optional[str] = None, ) -> None: - self.n_basis_funcs = n_basis_funcs - self._n_input_dimensionality = 0 + self._n_basis_funcs = getattr(self, "_n_basis_funcs", None) + self._n_input_dimensionality = getattr(self, "_n_input_dimensionality", 0) self._mode = mode - self._n_basis_input = None - - # these parameters are going to be set at the first call of `compute_features` - # since we cannot know a-priori how many features may be convolved - self._n_output_features = None - self._input_shape = None - if label is None: self._label = self.__class__.__name__ else: self._label = str(label) - self.kernel_ = None + self._check_n_basis_min() + + # specified only after inputs/input shapes are provided + self._n_basis_input_ = getattr(self, "_n_basis_input_", None) + self._input_shape_ = getattr(self, "_input_shape_", None) + + # initialize parent to None. This should not end in "_" because it is + # a permanent property of a basis, defined at composite basis init + self._parent = None @property def n_output_features(self) -> int | None: @@ -169,7 +169,9 @@ def n_output_features(self) -> int | None: provided to the basis is known. Therefore, before the first call to ``compute_features``, this property will return ``None``. After that call, ``n_output_features`` will be available. """ - return self._n_output_features + if self._n_basis_input_ is not None: + return self.n_basis_funcs * self._n_basis_input_[0] + return None @property def label(self) -> str: @@ -177,12 +179,12 @@ def label(self) -> str: return self._label @property - def n_basis_input(self) -> tuple | None: + def n_basis_input_(self) -> tuple | None: """Number of expected inputs. The number of inputs ``compute_feature`` expects. """ - return self._n_basis_input + return self._n_basis_input_ @property def n_basis_funcs(self): @@ -204,43 +206,6 @@ def mode(self): """Mode of operation, either ``"conv"`` or ``"eval"``.""" return self._mode - @staticmethod - def _apply_identifiability_constraints(X: NDArray): - """Apply identifiability constraints to a design matrix `X`. - - Removes columns from `X` until `[1, X]` is full rank to ensure the uniqueness - of the GLM (Generalized Linear Model) maximum-likelihood solution. This is particularly - crucial for models using bases like BSplines and CyclicBspline, which, due to their - construction, sum to 1 and can cause rank deficiency when combined with an intercept. - - For GLMs, this rank deficiency means that different sets of coefficients might yield - identical predicted rates and log-likelihood, complicating parameter learning, especially - in the absence of regularization. - - Parameters - ---------- - X: - The design matrix before applying the identifiability constraints. - - Returns - ------- - : - The adjusted design matrix with redundant columns dropped and columns mean-centered. - """ - - def add_constant(x): - return np.hstack((np.ones((x.shape[0], 1)), x)) - - rank = np.linalg.matrix_rank(add_constant(X)) - # mean center - X = X - np.nanmean(X, axis=0) - while rank < X.shape[1] + 1: - # drop a column - X = X[:, :-1] - # recompute rank - rank = np.linalg.matrix_rank(add_constant(X)) - return X - @check_transform_input def compute_features( self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor @@ -273,24 +238,112 @@ def compute_features( Subclasses should implement how to handle the transformation specific to their basis function types and operation modes. """ - if self._n_basis_input is None: + if self._n_basis_input_ is None: self.set_input_shape(*xi) self._check_input_shape_consistency(*xi) - self.set_kernel() + self._set_input_independent_states() return self._compute_features(*xi) @abc.abstractmethod def _compute_features( self, *xi: NDArray | Tsd | TsdFrame | TsdTensor ) -> FeatureMatrix: - """Convolve or evaluate the basis.""" + """Convolve or evaluate the basis. + + This method is intended to be equivalent to the sklearn transformer ``transform`` method. + As the latter, it computes the transformation assuming that all the states are already + pre-computed by ``_fit_basis``, a method corresponding to ``fit``. + + The method differs from transformer's ``transform`` for the structure of the input that it accepts. + In particular, ``_compute_features`` accepts a number of different time series, one per 1D basis component, + while ``transform`` requires all inputs to be concatenated in a single array. + """ + pass + + @abc.abstractmethod + def setup_basis(self, *xi: ArrayLike) -> FeatureMatrix: + """Pre-compute all basis state variables. + + This method is intended to be equivalent to the sklearn transformer ``fit`` method. + As the latter, it computes all the state attributes, and store it with the convention + that the attribute name **must** end with "_", for example ``self.kernel_``, + ``self._input_shape_``. + + The method differs from transformer's ``fit`` for the structure of the input that it accepts. + In particular, ``_fit_basis`` accepts a number of different time series, one per 1D basis component, + while ``fit`` requires all inputs to be concatenated in a single array. + """ pass @abc.abstractmethod - def set_kernel(self): - """Set kernel for conv basis and return self or just return self for eval.""" + def _set_input_independent_states(self): + """ + Compute all the basis states that do not depend on the input. + + An example of such state is the kernel_ for Conv baisis, which can be computed + without any input (it only depends on the basis type, the window size and the + number of basis elements). + """ pass + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): + """ + Set the expected input shape for the basis object. + + This method configures the shape of the input data that the basis object expects. + ``xi`` can be specified as an integer, a tuple of integers, or derived + from an array. The method also calculates the total number of input + features and output features based on the number of basis functions. + + Parameters + ---------- + xi : + The input shape specification. + - An integer: Represents the dimensionality of the input. A value of ``1`` is treated as scalar input. + - A tuple: Represents the exact input shape excluding the first axis (sample axis). + All elements must be integers. + - An array: The shape is extracted, excluding the first axis (assumed to be the sample axis). + + Raises + ------ + ValueError + If a tuple is provided and it contains non-integer elements. + + Returns + ------- + self : + Returns the instance itself to allow method chaining. + + Notes + ----- + All state attributes that depends on the input must be set in this method in order for + the API of basis to work correctly. In particular, this method is called by ``_basis_fit``, + which is equivalent to ``fit`` for a transformer. If any input dependent state + is not set in this method, then ``compute_features`` (equivalent to ``fit_transform``) will break. + + Separating states related to the input (settable with this method) and states that are unrelated + from the input (settable with ``set_kernel`` for Conv bases) is a deliberate design choice + that improves modularity. + + """ + if isinstance(xi, tuple): + if not all(isinstance(i, int) for i in xi): + raise ValueError( + f"The tuple provided contains non integer values. Tuple: {xi}." + ) + shape = xi + elif isinstance(xi, int): + shape = () if xi == 1 else (xi,) + else: + shape = xi.shape[1:] + + n_inputs = (int(np.prod(shape)),) + + self._input_shape_ = shape + + self._n_basis_input_ = n_inputs + return self + @abc.abstractmethod def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: """ @@ -381,13 +434,6 @@ def _check_transform_input( return xi - def _check_has_kernel(self) -> None: - """Check that the kernel is pre-computed.""" - if self.mode == "conv" and self.kernel_ is None: - raise ValueError( - "You must call `_set_kernel` before `_compute_features` when mode =`conv`." - ) - def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: """Evaluate the basis set on a grid of equi-spaced sample points. @@ -560,7 +606,7 @@ def _get_feature_slicing( Parameters ---------- n_inputs : - The number of input basis for each component, by default it uses ``self._n_basis_input``. + The number of input basis for each component, by default it uses ``self._n_basis_input_``. start_slice : The starting index for slicing, by default it starts from 0. split_by_input : @@ -582,9 +628,8 @@ def _get_feature_slicing( _merge_slicing_dicts : Merges multiple slicing dictionaries, handling keys conflicts. """ # Set default values for n_inputs and start_slice if not provided - n_inputs = n_inputs or self._n_basis_input + n_inputs = n_inputs or self._n_basis_input_ start_slice = start_slice or 0 - # Handle the default case for non-additive basis types # See overwritten method for recursion logic split_dict, start_slice = self._get_default_slicing( @@ -609,11 +654,9 @@ def _get_default_slicing( """Handle default slicing logic.""" if split_by_input: # should we remove this option? - if self._n_basis_input[0] == 1 or isinstance(self, MultiplicativeBasis): + if self._n_basis_input_[0] == 1 or isinstance(self, MultiplicativeBasis): split_dict = { - self.label: slice( - start_slice, start_slice + self._n_output_features - ) + self.label: slice(start_slice, start_slice + self.n_output_features) } else: split_dict = { @@ -622,14 +665,14 @@ def _get_default_slicing( start_slice + i * self.n_basis_funcs, start_slice + (i + 1) * self.n_basis_funcs, ) - for i in range(self._n_basis_input[0]) + for i in range(self._n_basis_input_[0]) } } else: split_dict = { - self.label: slice(start_slice, start_slice + self._n_output_features) + self.label: slice(start_slice, start_slice + self.n_output_features) } - start_slice += self._n_output_features + start_slice += self.n_output_features return split_dict, start_slice def split_by_feature( @@ -721,13 +764,13 @@ def is_leaf(val): # Apply the slicing using the custom leaf function out = jax.tree_util.tree_map(lambda sl: x[sl], index_dict, is_leaf=is_leaf) - # reshape the arrays to spilt by n_basis_input + # reshape the arrays to spilt by n_basis_input_ reshaped_out = dict() for i, vals in enumerate(out.items()): key, val = vals shape = list(val.shape) reshaped_out[key] = val.reshape( - shape[:axis] + [self._n_basis_input[i], -1] + shape[axis + 1 :] + shape[:axis] + [self._n_basis_input_[i], -1] + shape[axis + 1 :] ) return reshaped_out @@ -736,10 +779,10 @@ def _check_input_shape_consistency(self, x: NDArray): # remove sample axis and squeeze shape = x.shape[1:] - initialized = self._input_shape is not None - is_shape_match = self._input_shape == shape + initialized = self._input_shape_ is not None + is_shape_match = self._input_shape_ == shape if initialized and not is_shape_match: - expected_shape_str = "(n_samples, " + f"{self._input_shape}"[1:] + expected_shape_str = "(n_samples, " + f"{self._input_shape_}"[1:] expected_shape_str = expected_shape_str.replace(",)", ")") raise ValueError( f"Input shape mismatch detected.\n\n" @@ -752,52 +795,50 @@ def _check_input_shape_consistency(self, x: NDArray): "different shape, please create a new basis instance." ) - def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): - """ - Set the expected input shape for the basis object. + def _list_components(self): + """List all basis components. - This method configures the shape of the input data that the basis object expects. - ``xi`` can be specified as an integer, a tuple of integers, or derived - from an array. The method also calculates the total number of input - features and output features based on the number of basis functions. + This is re-implemented for composite basis in the mixin class. - Parameters - ---------- - xi : - The input shape specification. - - An integer: Represents the dimensionality of the input. A value of ``1`` is treated as scalar input. - - A tuple: Represents the exact input shape excluding the first axis (sample axis). - All elements must be integers. - - An array: The shape is extracted, excluding the first axis (assumed to be the sample axis). + Returns + ------- + A list with all 1d basis components. Raises ------ - ValueError - If a tuple is provided and it contains non-integer elements. - - Returns - ------- - self : - Returns the instance itself to allow method chaining. + RuntimeError + If the basis has multiple components. This would only happen if there is an + implementation issue, for example, if a composite basis is implemented but the + mixin class is not initialized, or if the _list_components method of the composite mixin + class is accidentally removed. """ - if isinstance(xi, tuple): - if not all(isinstance(i, int) for i in xi): - raise ValueError( - f"The tuple provided contains non integer values. Tuple: {xi}." - ) - shape = xi - elif isinstance(xi, int): - shape = () if xi == 1 else (xi,) - else: - shape = xi.shape[1:] + if hasattr(self, "basis1"): + raise RuntimeError( + "Composite basis must implement the _list_components method." + ) + return [self] - n_inputs = (int(np.prod(shape)),) + def __sklearn_clone__(self) -> Basis: + """Clone the basis while preserving attributes related to input shapes. - self._input_shape = shape + This method ensures that input shape attributes (e.g., `_n_basis_input_`, + `_input_shape_`) are preserved during cloning. Reinitializing the class + as in the regular sklearn clone would drop these attributes, rendering + cross-validation unusable. + The method also handles recursive cloning for composite basis structures. + """ + # clone recursively + if hasattr(self, "_basis1") and hasattr(self, "_basis2"): + basis1 = self._basis1.__sklearn_clone__() + basis2 = self._basis2.__sklearn_clone__() + klass = self.__class__(basis1, basis2) - self._n_basis_input = n_inputs - self._n_output_features = self.n_basis_funcs * self._n_basis_input[0] - return self + else: + klass = self.__class__(**self.get_params()) + + for attr_name in ["_n_basis_input_", "_input_shape_"]: + setattr(klass, attr_name, getattr(self, attr_name)) + return klass class AdditiveBasis(CompositeBasisMixin, Basis): @@ -811,11 +852,6 @@ class AdditiveBasis(CompositeBasisMixin, Basis): basis2 : Second basis object to add. - Attributes - ---------- - n_basis_funcs : int - Number of basis functions. - Examples -------- >>> # Generate sample data @@ -835,17 +871,31 @@ class AdditiveBasis(CompositeBasisMixin, Basis): """ def __init__(self, basis1: Basis, basis2: Basis) -> None: - self.n_basis_funcs = basis1.n_basis_funcs + basis2.n_basis_funcs - super().__init__(self.n_basis_funcs, mode="eval") + CompositeBasisMixin.__init__(self, basis1, basis2) + Basis.__init__(self, mode="eval") + self._label = "(" + basis1.label + " + " + basis2.label + ")" + self._n_input_dimensionality = ( basis1._n_input_dimensionality + basis2._n_input_dimensionality ) - self._n_basis_input = None - self._n_output_features = None - self._label = "(" + basis1.label + " + " + basis2.label + ")" - self._basis1 = basis1 - self._basis2 = basis2 - CompositeBasisMixin.__init__(self) + + @property + def n_basis_funcs(self): + """Compute the n-basis function runtime. + + This plays well with cross-validation where the number of basis function of the + underlying bases can be changed. It must be read-only since the number of basis + is determined by the two basis elements and the type of composition. + """ + return self.basis1.n_basis_funcs + self.basis2.n_basis_funcs + + @property + def n_output_features(self): + out1 = getattr(self._basis1, "n_output_features", None) + out2 = getattr(self._basis2, "n_output_features", None) + if out1 is None or out2 is None: + return None + return out1 + out2 def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: """ @@ -896,16 +946,13 @@ def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: 181 """ - self._n_basis_input = ( + self._n_basis_input_ = ( *self._basis1.set_input_shape( *xi[: self._basis1._n_input_dimensionality] - )._n_basis_input, + )._n_basis_input_, *self._basis2.set_input_shape( *xi[self._basis1._n_input_dimensionality :] - )._n_basis_input, - ) - self._n_output_features = ( - self._basis1.n_output_features + self._basis2.n_output_features + )._n_basis_input_, ) return self @@ -1209,18 +1256,18 @@ def _get_feature_slicing( _merge_slicing_dicts : Merges multiple slicing dictionaries, handling keys conflicts. """ # Set default values for n_inputs and start_slice if not provided - n_inputs = n_inputs or self._n_basis_input + n_inputs = n_inputs or self._n_basis_input_ start_slice = start_slice or 0 # If the instance is of AdditiveBasis type, handle slicing for the additive components split_dict, start_slice = self._basis1._get_feature_slicing( - n_inputs[: len(self._basis1._n_basis_input)], + n_inputs[: len(self._basis1._n_basis_input_)], start_slice, split_by_input=split_by_input, ) sp2, start_slice = self._basis2._get_feature_slicing( - n_inputs[len(self._basis1._n_basis_input) :], + n_inputs[len(self._basis1._n_basis_input_) :], start_slice, split_by_input=split_by_input, ) @@ -1249,11 +1296,6 @@ class MultiplicativeBasis(CompositeBasisMixin, Basis): basis2 : Second basis object to multiply. - Attributes - ---------- - n_basis_funcs : - Number of basis functions. - Examples -------- >>> # Generate sample data @@ -1273,39 +1315,30 @@ class MultiplicativeBasis(CompositeBasisMixin, Basis): """ def __init__(self, basis1: Basis, basis2: Basis) -> None: - self.n_basis_funcs = basis1.n_basis_funcs * basis2.n_basis_funcs - CompositeBasisMixin.__init__(self) - super().__init__(self.n_basis_funcs, mode="eval") + CompositeBasisMixin.__init__(self, basis1, basis2) + Basis.__init__(self, mode="eval") + self._label = "(" + basis1.label + " * " + basis2.label + ")" self._n_input_dimensionality = ( basis1._n_input_dimensionality + basis2._n_input_dimensionality ) - self._n_basis_input = None - self._n_output_features = None - self._label = "(" + basis1.label + " * " + basis2.label + ")" - self._basis1 = basis1 - self._basis2 = basis2 - - def _check_n_basis_min(self) -> None: - pass - - def set_kernel(self, *xi: NDArray) -> Basis: - """Call fit on the multiplied basis. - - If any of the added basis is in "conv" mode, it will prepare its kernels for the convolution. - Parameters - ---------- - *xi: - The sample inputs. Unused, necessary to conform to ``scikit-learn`` API. + @property + def n_basis_funcs(self): + """Compute the n-basis function runtime. - Returns - ------- - : - The MultiplicativeBasis ready to be evaluated. + This plays well with cross-validation where the number of basis function of the + underlying bases can be changed. It must be read-only since the number of basis + is determined by the two basis elements and the type of composition. """ - self._basis1.set_kernel() - self._basis2.set_kernel() - return self + return self.basis1.n_basis_funcs * self.basis2.n_basis_funcs + + @property + def n_output_features(self): + out1 = getattr(self._basis1, "n_output_features", None) + out2 = getattr(self._basis2, "n_output_features", None) + if out1 is None or out2 is None: + return None + return out1 * out2 @support_pynapple(conv_type="numpy") @check_transform_input @@ -1423,16 +1456,13 @@ def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: 25200 """ - self._n_basis_input = ( + self._n_basis_input_ = ( *self._basis1.set_input_shape( *xi[: self._basis1._n_input_dimensionality] - )._n_basis_input, + )._n_basis_input_, *self._basis2.set_input_shape( *xi[self._basis1._n_input_dimensionality :] - )._n_basis_input, - ) - self._n_output_features = ( - self._basis1.n_output_features * self._basis2.n_output_features + )._n_basis_input_, ) return self diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 098aeb34..9b208e31 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -2,8 +2,10 @@ from __future__ import annotations +import abc import copy import inspect +import warnings from typing import TYPE_CHECKING, Optional, Tuple, Union import numpy as np @@ -20,8 +22,11 @@ class EvalBasisMixin: """Mixin class for evaluational basis.""" - def __init__(self, bounds: Optional[Tuple[float, float]] = None): + def __init__( + self, n_basis_funcs: int, bounds: Optional[Tuple[float, float]] = None + ): self.bounds = bounds + self._n_basis_funcs = n_basis_funcs def _compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor): """Evaluate basis at sample points. @@ -51,9 +56,32 @@ def _compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor): out = self._evaluate(*(np.reshape(x, (x.shape[0], -1)) for x in xi)) return np.reshape(out, (out.shape[0], -1)) - def set_kernel(self) -> "EvalBasisMixin": + def setup_basis(self, *xi: NDArray) -> Basis: """ - Prepare or compute the convolutional kernel for the basis functions. + Set all basis states. + + This method corresponds sklearn transformer ``fit``. As fit, it must receive the input and + it must set all basis states, i.e. kernel_ and all the states relative to the input shape. + The difference between this method and the transformer ``fit`` is in the expected input structure, + where the transformer ``fit`` method requires the inputs to be concatenated in a 2D array, while here + each input is provided as a separate time series for each basis element. + + Parameters + ---------- + xi: + Input arrays. + + Returns + ------- + : + The basis with ready for evaluation. + """ + self.set_input_shape(*xi) + return self + + def _set_input_independent_states(self) -> "EvalBasisMixin": + """ + Compute all the basis states that do not depend on the input. For EvalBasisMixin, this method might not perform any operation but simply return the instance itself, as no kernel preparation is necessary. @@ -94,9 +122,13 @@ def bounds(self, values: Union[None, Tuple[float, float]]): class ConvBasisMixin: """Mixin class for convolutional basis.""" - def __init__(self, window_size: int, conv_kwargs: Optional[dict] = None): + def __init__( + self, n_basis_funcs: int, window_size: int, conv_kwargs: Optional[dict] = None + ): + self.kernel_ = None self.window_size = window_size self.conv_kwargs = {} if conv_kwargs is None else conv_kwargs + self._n_basis_funcs = n_basis_funcs def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor): """Convolve basis functions with input time series. @@ -114,10 +146,18 @@ def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor): The input data over which to apply the basis transformation. The samples can be passed as multiple arguments, each representing a different dimension for multivariate inputs. + Notes + ----- + This method is intended to be 1-to-1 mappable to sklearn ``transform`` method of transformer. This + means that for the method to be callable, all the state attributes have to be pre-computed in a + method that is mappable to ``fit``, which for us is ``_fit_basis``. It is fundamental that both + methods behaves like the corresponding transformer method, with the only difference being the input + structure: a single (X, y) pair for the transformer, a number of time series for the Basis. + """ if self.kernel_ is None: raise ValueError( - "You must call `_set_kernel` before `_compute_features`! " + "You must call `setup_basis` before `_compute_features`! " "Convolution kernel is not set." ) # before calling the convolve, check that the input matches @@ -127,6 +167,38 @@ def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor): # make sure to return a matrix return np.reshape(conv, newshape=(conv.shape[0], -1)) + def setup_basis(self, *xi: NDArray) -> Basis: + """ + Set all basis states. + + This method corresponds sklearn transformer ``fit``. As fit, it must receive the input and + it must set all basis states, i.e. kernel_ and all the states relative to the input shape. + The difference between this method and the transformer ``fit`` is in the expected input structure, + where the transformer ``fit`` method requires the inputs to be concatenated in a 2D array, while here + each input is provided as a separate time series for each basis element. + + Parameters + ---------- + xi: + Input arrays. + + Returns + ------- + : + The basis with ready for evaluation. + """ + self.set_kernel() + self.set_input_shape(*xi) + return self + + def _set_input_independent_states(self): + """ + Compute all the basis states that do not depend on the input. + + For Conv mixin the only attribute is the kernel. + """ + return self.set_kernel() + def set_kernel(self) -> "ConvBasisMixin": """ Prepare or compute the convolutional kernel for the basis functions. @@ -160,6 +232,11 @@ def window_size(self): @window_size.setter def window_size(self, window_size): """Setter for the window size parameter.""" + self._check_window_size(window_size) + + self._window_size = window_size + + def _check_window_size(self, window_size): if window_size is None: raise ValueError("You must provide a window_size!") @@ -168,8 +245,6 @@ def window_size(self, window_size): f"`window_size` must be a positive integer. {window_size} provided instead!" ) - self._window_size = window_size - @property def conv_kwargs(self): """The convolutional kwargs. @@ -227,6 +302,13 @@ def _check_convolution_kwargs(conv_kwargs: dict): f"Allowed convolution keyword arguments are: {convolve_configs}." ) + def _check_has_kernel(self) -> None: + """Check that the kernel is pre-computed.""" + if self.kernel_ is None: + raise ValueError( + "You must call `_set_kernel` before `_compute_features` for Conv basis." + ) + class BasisTransformerMixin: """Mixin class for constructing a transformer.""" @@ -244,7 +326,7 @@ def to_transformer(self) -> TransformerBasis: >>> from sklearn.model_selection import GridSearchCV >>> # load some data >>> X, y = np.random.normal(size=(30, 1)), np.random.poisson(size=30) - >>> basis = nmo.basis.RaisedCosineLinearEval(10).to_transformer() + >>> basis = nmo.basis.RaisedCosineLinearEval(10).set_input_shape(1).to_transformer() >>> glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=1.) >>> pipeline = Pipeline([("basis", basis), ("glm", glm)]) >>> param_grid = dict( @@ -258,7 +340,7 @@ def to_transformer(self) -> TransformerBasis: ... ) >>> gridsearch = gridsearch.fit(X, y) """ - return TransformerBasis(copy.deepcopy(self)) + return TransformerBasis(self) class CompositeBasisMixin: @@ -268,28 +350,82 @@ class CompositeBasisMixin: (AdditiveBasis and MultiplicativeBasis). """ + def __init__(self, basis1: Basis, basis2: Basis): + # deep copy to avoid changes directly to the 1d basis to be reflected + # in the composite basis. + self.basis1 = copy.deepcopy(basis1) + self.basis2 = copy.deepcopy(basis2) + + # set parents + self.basis1._parent = self + self.basis2._parent = self + + shapes = ( + *(bas1._input_shape_ for bas1 in basis1._list_components()), + *(bas2._input_shape_ for bas2 in basis2._list_components()), + ) + # if all bases where set, then set input for composition. + set_bases = (s is not None for s in shapes) + + if all(set_bases): + # pass down the input shapes + self.set_input_shape(*shapes) + elif any(set_bases): + warnings.warn( + "Only some of the basis where initialized with `set_input_shape`, " + "please initialize the composite basis before computing features.", + category=UserWarning, + ) + + @property + @abc.abstractmethod + def n_basis_funcs(self): + """Read only property for composite bases.""" + pass + def _check_n_basis_min(self) -> None: pass - def set_kernel(self, *xi: NDArray) -> Basis: - """Call set_kernel on the basis elements. + def setup_basis(self, *xi: NDArray) -> Basis: + """ + Set all basis states. - If any of the basis elements is in "conv" mode, it will prepare its kernels for the convolution. + This method corresponds sklearn transformer ``fit``. As fit, it must receive the input and + it must set all basis states, i.e. kernel_ and all the states relative to the input shape. + The difference between this method and the transformer ``fit`` is in the expected input structure, + where the transformer ``fit`` method requires the inputs to be concatenated in a 2D array, while here + each input is provided as a separate time series for each basis element. Parameters ---------- - *xi: - The sample inputs. Unused, necessary to conform to ``scikit-learn`` API. + xi: + Input arrays. Returns ------- : - The basis ready to be evaluated. + The basis with ready for evaluation. """ - self._basis1.set_kernel() - self._basis2.set_kernel() + # setup both input independent + self._set_input_independent_states() + + # and input dependent states + self.set_input_shape(*xi) + return self + def _set_input_independent_states(self): + """ + Compute the input dependent states for traversing the composite basis. + + Returns + ------- + : + The basis with the states stored as attributes of each component. + """ + self.basis1._set_input_independent_states() + self.basis2._set_input_independent_states() + def _check_input_shape_consistency(self, *xi: NDArray): """Check the input shape consistency for all basis elements.""" self._basis1._check_input_shape_consistency( @@ -298,3 +434,31 @@ def _check_input_shape_consistency(self, *xi: NDArray): self._basis2._check_input_shape_consistency( *xi[self._basis1._n_input_dimensionality :] ) + + @property + def basis1(self): + return self._basis1 + + @basis1.setter + def basis1(self, bas: Basis): + self._basis1 = bas + + @property + def basis2(self): + return self._basis2 + + @basis2.setter + def basis2(self, bas: Basis): + self._basis2 = bas + + def _list_components(self): + """List all basis components. + + Reimplements the default behavior by iteratively calling _list_components of the + elements. + + Returns + ------- + A list with all 1d basis components. + """ + return self._basis1._list_components() + self._basis2._list_components() diff --git a/src/nemos/basis/_decaying_exponential.py b/src/nemos/basis/_decaying_exponential.py index 65a71a3e..7df05947 100644 --- a/src/nemos/basis/_decaying_exponential.py +++ b/src/nemos/basis/_decaying_exponential.py @@ -21,8 +21,6 @@ class OrthExponentialBasis(Basis, abc.ABC): Parameters ---------- - n_basis_funcs - Number of basis functions. decay_rates : Decay rates of the exponentials, shape ``(n_basis_funcs,)``. mode : @@ -35,13 +33,11 @@ class OrthExponentialBasis(Basis, abc.ABC): def __init__( self, - n_basis_funcs: int, decay_rates: NDArray[np.floating], mode="eval", label: Optional[str] = "OrthExponentialBasis", ): super().__init__( - n_basis_funcs, mode=mode, label=label, ) diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index 07c3ae0a..0521a683 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -21,8 +21,6 @@ class RaisedCosineBasisLinear(Basis, abc.ABC): Parameters ---------- - n_basis_funcs : - The number of basis functions. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. @@ -42,13 +40,11 @@ class RaisedCosineBasisLinear(Basis, abc.ABC): def __init__( self, - n_basis_funcs: int, mode="eval", width: float = 2.0, label: Optional[str] = "RaisedCosineBasisLinear", ) -> None: super().__init__( - n_basis_funcs, mode=mode, label=label, ) @@ -234,7 +230,6 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear, abc.ABC): def __init__( self, - n_basis_funcs: int, mode="eval", width: float = 2.0, time_scaling: float = None, @@ -242,7 +237,6 @@ def __init__( label: Optional[str] = "RaisedCosineBasisLog", ) -> None: super().__init__( - n_basis_funcs, mode=mode, width=width, label=label, diff --git a/src/nemos/basis/_spline_basis.py b/src/nemos/basis/_spline_basis.py index d9969029..5fc4c38e 100644 --- a/src/nemos/basis/_spline_basis.py +++ b/src/nemos/basis/_spline_basis.py @@ -21,8 +21,6 @@ class SplineBasis(Basis, abc.ABC): Parameters ---------- - n_basis_funcs : - Number of basis functions. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. @@ -40,14 +38,12 @@ class SplineBasis(Basis, abc.ABC): def __init__( self, - n_basis_funcs: int, order: int = 2, label: Optional[str] = None, mode: Literal["conv", "eval"] = "eval", ) -> None: self.order = order super().__init__( - n_basis_funcs, label=label, mode=mode, ) @@ -158,9 +154,6 @@ class MSplineBasis(SplineBasis, abc.ABC): Parameters ---------- - n_basis_funcs : - The number of basis functions to generate. More basis functions allow for - more flexible data modeling but can lead to overfitting. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. @@ -198,13 +191,11 @@ class MSplineBasis(SplineBasis, abc.ABC): def __init__( self, - n_basis_funcs: int, mode: Literal["eval", "conv"] = "eval", order: int = 2, label: Optional[str] = "MSplineEval", ) -> None: super().__init__( - n_basis_funcs, mode=mode, order=order, label=label, @@ -301,8 +292,6 @@ class BSplineBasis(SplineBasis, abc.ABC): Parameters ---------- - n_basis_funcs : - Number of basis functions. mode : The mode of operation. ``'eval'`` for evaluation at sample points, 'conv' for convolutional operation. @@ -328,13 +317,11 @@ class BSplineBasis(SplineBasis, abc.ABC): def __init__( self, - n_basis_funcs: int, mode="eval", order: int = 4, label: Optional[str] = "BSplineBasis", ): super().__init__( - n_basis_funcs, mode=mode, order=order, label=label, @@ -419,8 +406,6 @@ class CyclicBSplineBasis(SplineBasis, abc.ABC): Parameters ---------- - n_basis_funcs : - Number of basis functions. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. @@ -442,13 +427,11 @@ class CyclicBSplineBasis(SplineBasis, abc.ABC): def __init__( self, - n_basis_funcs: int, mode="eval", order: int = 4, label: Optional[str] = "CyclicBSplineBasis", ): super().__init__( - n_basis_funcs, mode=mode, order=order, label=label, diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 83f4f2e3..91865028 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -1,7 +1,9 @@ from __future__ import annotations import copy -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List + +import numpy as np from ..typing import FeatureMatrix @@ -63,12 +65,28 @@ def __init__(self, basis: Basis): self._basis = copy.deepcopy(basis) @staticmethod - def _unpack_inputs(X: FeatureMatrix): - """Unpack inputs without using transpose. + def _check_initialized(basis): + if basis._n_basis_input_ is None: + raise RuntimeError( + "Cannot apply TransformerBasis: the provided basis has no defined input shape. " + "Please call `set_input_shape` before calling `fit`, `transform`, or " + "`fit_transform`." + ) + + @property + def basis(self): + return self._basis + + @basis.setter + def basis(self, basis): + self._check_initialized(basis) + self._basis = basis + + def _unpack_inputs(self, X: FeatureMatrix) -> List: + """Unpack inputs. Unpack horizontally stacked inputs using slicing. This works gracefully with ``pynapple``, - returning a list of Tsd objects. Attempt to unpack using *X.T will raise a ``pynapple`` - exception since ``pynapple`` assumes that the time axis is the first axis. + returning a list of Tsd objects. Parameters ---------- @@ -78,10 +96,19 @@ def _unpack_inputs(X: FeatureMatrix): Returns ------- : - A tuple of each individual input. + A list of each individual input. """ - return (X[:, k] for k in range(X.shape[1])) + n_samples = X.shape[0] + out = [] + cc = 0 + for i, bas in enumerate(self._list_components()): + n_input = self._n_basis_input_[i] + out.append( + np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape_)) + ) + cc += n_input + return out def fit(self, X: FeatureMatrix, y=None): """ @@ -110,11 +137,11 @@ def fit(self, X: FeatureMatrix, y=None): >>> X = np.random.normal(size=(100, 2)) >>> # Define and fit tranformation basis - >>> basis = MSplineEval(10) + >>> basis = MSplineEval(10).set_input_shape(2) >>> transformer = TransformerBasis(basis) >>> transformer_fitted = transformer.fit(X) """ - self._basis.set_kernel() + self._basis.setup_basis(*self._unpack_inputs(X)) return self def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: @@ -141,7 +168,7 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: >>> # Example input >>> X = np.random.normal(size=(10000, 2)) - >>> basis = MSplineConv(10, window_size=200) + >>> basis = MSplineConv(10, window_size=200).set_input_shape(2) >>> transformer = TransformerBasis(basis) >>> # Before calling `fit` the convolution kernel is not set >>> transformer.kernel_ @@ -152,7 +179,7 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: (200, 10) >>> # Transform basis - >>> feature_transformed = transformer.transform(X[:, 0:1]) + >>> feature_transformed = transformer.transform(X) """ # transpose does not work with pynapple # can't use func(*X.T) to unwrap @@ -187,7 +214,7 @@ def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: >>> X = np.random.normal(size=(100, 1)) >>> # Define tranformation basis - >>> basis = MSplineEval(10) + >>> basis = MSplineEval(10).set_input_shape(1) >>> transformer = TransformerBasis(basis) >>> # Fit and transform basis diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 0fee5651..702f7b56 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -83,10 +83,9 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "BSplineEval", ): - EvalBasisMixin.__init__(self, bounds=bounds) + EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) BSplineBasis.__init__( self, - n_basis_funcs, mode="eval", order=order, label=label, @@ -237,10 +236,14 @@ def __init__( label: Optional[str] = "BSplineConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) + ConvBasisMixin.__init__( + self, + n_basis_funcs=n_basis_funcs, + window_size=window_size, + conv_kwargs=conv_kwargs, + ) BSplineBasis.__init__( self, - n_basis_funcs, mode="conv", order=order, label=label, @@ -378,10 +381,9 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "CyclicBSplineEval", ): - EvalBasisMixin.__init__(self, bounds=bounds) + EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) CyclicBSplineBasis.__init__( self, - n_basis_funcs, mode="eval", order=order, label=label, @@ -524,10 +526,14 @@ def __init__( label: Optional[str] = "CyclicBSplineConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) + ConvBasisMixin.__init__( + self, + n_basis_funcs=n_basis_funcs, + window_size=window_size, + conv_kwargs=conv_kwargs, + ) CyclicBSplineBasis.__init__( self, - n_basis_funcs, mode="conv", order=order, label=label, @@ -689,10 +695,9 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "MSplineEval", ): - EvalBasisMixin.__init__(self, bounds=bounds) + EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) MSplineBasis.__init__( self, - n_basis_funcs, mode="eval", order=order, label=label, @@ -859,10 +864,14 @@ def __init__( label: Optional[str] = "MSplineConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) + ConvBasisMixin.__init__( + self, + n_basis_funcs=n_basis_funcs, + window_size=window_size, + conv_kwargs=conv_kwargs, + ) MSplineBasis.__init__( self, - n_basis_funcs, mode="conv", order=order, label=label, @@ -1008,10 +1017,9 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "RaisedCosineLinearEval", ): - EvalBasisMixin.__init__(self, bounds=bounds) + EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) RaisedCosineBasisLinear.__init__( self, - n_basis_funcs, width=width, mode="eval", label=label, @@ -1155,10 +1163,14 @@ def __init__( label: Optional[str] = "RaisedCosineLinearConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) + ConvBasisMixin.__init__( + self, + n_basis_funcs=n_basis_funcs, + window_size=window_size, + conv_kwargs=conv_kwargs, + ) RaisedCosineBasisLinear.__init__( self, - n_basis_funcs, mode="conv", width=width, label=label, @@ -1311,10 +1323,9 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "RaisedCosineLogEval", ): - EvalBasisMixin.__init__(self, bounds=bounds) + EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) RaisedCosineBasisLog.__init__( self, - n_basis_funcs, width=width, time_scaling=time_scaling, enforce_decay_to_zero=enforce_decay_to_zero, @@ -1470,10 +1481,14 @@ def __init__( label: Optional[str] = "RaisedCosineLogConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) + ConvBasisMixin.__init__( + self, + n_basis_funcs=n_basis_funcs, + window_size=window_size, + conv_kwargs=conv_kwargs, + ) RaisedCosineBasisLog.__init__( self, - n_basis_funcs, mode="conv", width=width, time_scaling=time_scaling, @@ -1608,10 +1623,9 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "OrthExponentialEval", ): - EvalBasisMixin.__init__(self, bounds=bounds) + EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) OrthExponentialBasis.__init__( self, - n_basis_funcs, decay_rates=decay_rates, mode="eval", label=label, @@ -1751,14 +1765,21 @@ def __init__( label: Optional[str] = "OrthExponentialConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) + ConvBasisMixin.__init__( + self, + n_basis_funcs=n_basis_funcs, + window_size=window_size, + conv_kwargs=conv_kwargs, + ) OrthExponentialBasis.__init__( self, - n_basis_funcs, mode="conv", decay_rates=decay_rates, label=label, ) + # re-check window size because n_basis_funcs is not set yet when the + # property setter runs the first check. + self._check_window_size(self.window_size) @add_docstring("evaluate_on_grid", OrthExponentialBasis) def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: @@ -1850,3 +1871,31 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ return super().set_input_shape(xi) + + def _check_window_size(self, window_size: int): + """OrthExponentialBasis specific window size check.""" + super()._check_window_size(window_size) + # if n_basis_funcs is not yet initialized, skip check + n_basis = getattr(self, "n_basis_funcs", None) + if n_basis and window_size < n_basis: + raise ValueError( + "OrthExponentialConv basis requires at least a window_size larger then the number " + f"of basis functions. window_size is {window_size}, n_basis_funcs while" + f"is {self.n_basis_funcs}." + ) + + def set_kernel(self): + try: + super().set_kernel() + except ValueError as e: + if "OrthExponentialBasis requires at least as many" in str(e): + raise ValueError( + "Cannot set the kernels for OrthExponentialBasis when `window_size` is smaller " + "than `n_basis_funcs.\n" + "Please, increase the window size or reduce the number of basis functions. " + f"Current `window_size` is {self.window_size}, while `n_basis_funcs` is " + f"{self.n_basis_funcs}." + ) + else: + raise e + return self diff --git a/tests/conftest.py b/tests/conftest.py index eb88ed10..3daba960 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,8 @@ and loading predefined parameters for testing various functionalities of the NeMoS library. """ +import abc + import jax import jax.numpy as jnp import numpy as np @@ -16,11 +18,112 @@ import pytest import nemos as nmo +import nemos._inspect_utils as inspect_utils +import nemos.basis.basis as basis +from nemos.basis import AdditiveBasis, MultiplicativeBasis +from nemos.basis._basis import Basis # shut-off conversion warnings nap.nap_config.suppress_conversion_warnings = True +@pytest.fixture() +def basis_class_specific_params(): + """Returns all the params for each class.""" + all_cls = list_all_basis_classes("Conv") + list_all_basis_classes("Eval") + return {cls.__name__: cls._get_param_names() for cls in all_cls} + + +class BasisFuncsTesting(abc.ABC): + """ + An abstract base class that sets the foundation for individual basis function testing. + This class requires an implementation of a 'cls' method, which is utilized by the meta-test + that verifies if all basis functions are properly tested. + """ + + @abc.abstractmethod + def cls(self): + pass + + +class CombinedBasis(BasisFuncsTesting): + """ + This class is used to run tests on combination operations (e.g., addition, multiplication) among Basis functions. + + Properties: + - cls: Class (default = None) + """ + + cls = None + + @staticmethod + def instantiate_basis( + n_basis, basis_class, class_specific_params, window_size=10, **kwargs + ): + """Instantiate and return two basis of the type specified.""" + + # Set non-optional args + default_kwargs = { + "n_basis_funcs": n_basis, + "window_size": window_size, + "decay_rates": np.arange(1, 1 + n_basis), + } + repeated_keys = set(default_kwargs.keys()).intersection(kwargs.keys()) + if repeated_keys: + raise ValueError( + "Cannot set `n_basis_funcs, window_size, decay_rates` with kwargs" + ) + + # Merge with provided extra kwargs + kwargs = {**default_kwargs, **kwargs} + + if basis_class == AdditiveBasis: + kwargs_mspline = inspect_utils.trim_kwargs( + basis.MSplineEval, kwargs, class_specific_params + ) + kwargs_raised_cosine = inspect_utils.trim_kwargs( + basis.RaisedCosineLinearConv, kwargs, class_specific_params + ) + b1 = basis.MSplineEval(**kwargs_mspline) + b2 = basis.RaisedCosineLinearConv(**kwargs_raised_cosine) + basis_obj = b1 + b2 + elif basis_class == MultiplicativeBasis: + kwargs_mspline = inspect_utils.trim_kwargs( + basis.MSplineEval, kwargs, class_specific_params + ) + kwargs_raised_cosine = inspect_utils.trim_kwargs( + basis.RaisedCosineLinearConv, kwargs, class_specific_params + ) + b1 = basis.MSplineEval(**kwargs_mspline) + b2 = basis.RaisedCosineLinearConv(**kwargs_raised_cosine) + basis_obj = b1 * b2 + else: + basis_obj = basis_class( + **inspect_utils.trim_kwargs(basis_class, kwargs, class_specific_params) + ) + return basis_obj + + +# automatic define user accessible basis and check the methods +def list_all_basis_classes(filter_basis="all") -> list[type]: + """ + Return all the classes in nemos.basis which are a subclass of Basis, + which should be all concrete classes except TransformerBasis. + """ + all_basis = [ + class_obj + for _, class_obj in inspect_utils.get_non_abstract_classes(basis) + if issubclass(class_obj, Basis) + ] + [ + bas + for _, bas in inspect_utils.get_non_abstract_classes(nmo.basis._basis) + if bas != basis.TransformerBasis + ] + if filter_basis != "all": + all_basis = [a for a in all_basis if filter_basis in a.__name__] + return all_basis + + # Sample subclass to test instantiation and methods class MockRegressor(nmo.base_regressor.BaseRegressor): """ diff --git a/tests/test_basis.py b/tests/test_basis.py index d6168bc8..86c58fa2 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -1,4 +1,3 @@ -import abc import inspect import itertools import pickle @@ -11,9 +10,8 @@ import numpy as np import pynapple as nap import pytest -from sklearn.base import clone as sk_clone +from conftest import BasisFuncsTesting, CombinedBasis, list_all_basis_classes -import nemos as nmo import nemos._inspect_utils as inspect_utils import nemos.basis.basis as basis import nemos.convolve as convolve @@ -34,33 +32,6 @@ def extra_decay_rates(cls, n_basis): return {} -# automatic define user accessible basis and check the methods -def list_all_basis_classes(filter_basis="all") -> list[type]: - """ - Return all the classes in nemos.basis which are a subclass of Basis, - which should be all concrete classes except TransformerBasis. - """ - all_basis = [ - class_obj - for _, class_obj in inspect_utils.get_non_abstract_classes(basis) - if issubclass(class_obj, Basis) - ] + [ - bas - for _, bas in inspect_utils.get_non_abstract_classes(nmo.basis._basis) - if bas != basis.TransformerBasis - ] - if filter_basis != "all": - all_basis = [a for a in all_basis if filter_basis in a.__name__] - return all_basis - - -@pytest.fixture() -def class_specific_params(): - """Returns all the params for each class.""" - all_cls = list_all_basis_classes("Conv") + list_all_basis_classes("Eval") - return {cls.__name__: cls._get_param_names() for cls in all_cls} - - def test_all_basis_are_tested() -> None: """Meta-test. @@ -137,11 +108,11 @@ def test_all_basis_are_tested() -> None: ], ) def test_example_docstrings_add( - basis_cls, method_name, descr_match, class_specific_params + basis_cls, method_name, descr_match, basis_class_specific_params ): basis_instance = CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 + 5, basis_cls, basis_class_specific_params, window_size=10 ) method = getattr(basis_instance, method_name) doc = method.__doc__ @@ -285,19 +256,6 @@ def test_expected_output_split_by_feature(basis_instance, super_class): np.testing.assert_array_equal(xx[~nans], x[~nans]) -class BasisFuncsTesting(abc.ABC): - """ - An abstract base class that sets the foundation for individual basis function testing. - This class requires an implementation of a 'cls' method, which is utilized by the meta-test - that verifies if all basis functions are properly tested. - """ - - @abc.abstractmethod - def cls(self): - pass - - -# Auto-generated file with stripped classes and shared methods @pytest.mark.parametrize( "cls", [ @@ -343,12 +301,44 @@ def test_call_vmin_vmax(self, samples, vmin, vmax, expectation, cls): with expectation: bas._evaluate(samples) + @pytest.mark.parametrize("n_basis", [5, 6]) + @pytest.mark.parametrize("vmin, vmax", [(0, 1), (-1, 1)]) + @pytest.mark.parametrize("inp_num", [1, 2]) + def test_sklearn_clone_eval(self, cls, n_basis, vmin, vmax, inp_num): + bas = cls["eval"]( + n_basis, bounds=(vmin, vmax), **extra_decay_rates(cls["eval"], n_basis) + ) + bas.set_input_shape(inp_num) + bas2 = bas.__sklearn_clone__() + assert id(bas) != id(bas2) + assert np.all( + bas.__dict__.pop("decay_rates", True) + == bas2.__dict__.pop("decay_rates", True) + ) + assert bas.__dict__ == bas2.__dict__ + + @pytest.mark.parametrize("n_basis", [5, 6]) + @pytest.mark.parametrize("ws", [10, 20]) + @pytest.mark.parametrize("inp_num", [1, 2]) + def test_sklearn_clone_conv(self, cls, n_basis, ws, inp_num): + bas = cls["conv"]( + n_basis, window_size=ws, **extra_decay_rates(cls["eval"], n_basis) + ) + bas.set_input_shape(inp_num) + bas2 = bas.__sklearn_clone__() + assert id(bas) != id(bas2) + assert np.all( + bas.__dict__.pop("decay_rates", True) + == bas2.__dict__.pop("decay_rates", True) + ) + assert bas.__dict__ == bas2.__dict__ + @pytest.mark.parametrize( "attribute, value", [ ("label", None), ("label", "label"), - ("n_basis_input", 1), + ("n_basis_input_", 1), ("n_output_features", 5), ], ) @@ -442,10 +432,10 @@ def test_set_num_basis_input(self, n_input, cls): bas = cls["conv"]( n_basis_funcs=5, window_size=10, **extra_decay_rates(cls["conv"], 5) ) - assert bas.n_basis_input is None + assert bas.n_basis_input_ is None bas.compute_features(np.random.randn(20, n_input)) - assert bas.n_basis_input == (n_input,) - assert bas._n_basis_input == (n_input,) + assert bas.n_basis_input_ == (n_input,) + assert bas._n_basis_input_ == (n_input,) @pytest.mark.parametrize( "bounds, samples, nan_idx, mn, mx", @@ -520,7 +510,7 @@ def test_vmin_vmax_init(self, bounds, expectation, cls): @pytest.mark.parametrize("n_basis", [6, 7]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_basis_number(self, n_basis, mode, kwargs, cls): @@ -552,7 +542,7 @@ def test_call_equivalent_in_conv(self, n_basis, cls): ], ) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) @pytest.mark.parametrize("n_basis", [6]) def test_call_input_num(self, num_input, n_basis, mode, kwargs, expectation, cls): @@ -572,7 +562,7 @@ def test_call_input_num(self, num_input, n_basis, mode, kwargs, expectation, cls ) @pytest.mark.parametrize("n_basis", [6]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_input_shape(self, inp, mode, kwargs, expectation, n_basis, cls): bas = cls[mode]( @@ -586,7 +576,7 @@ def test_call_input_shape(self, inp, mode, kwargs, expectation, n_basis, cls): @pytest.mark.parametrize("n_basis", [6]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_nan_location(self, mode, kwargs, n_basis, cls): bas = cls[mode]( @@ -619,7 +609,7 @@ def test_call_input_type(self, samples, expectation, n_basis, cls): bas._evaluate(samples) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_nan(self, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) @@ -629,7 +619,7 @@ def test_call_nan(self, mode, kwargs, cls): @pytest.mark.parametrize("n_basis", [6, 7]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_non_empty(self, n_basis, mode, kwargs, cls): bas = cls[mode]( @@ -640,7 +630,7 @@ def test_call_non_empty(self, n_basis, mode, kwargs, cls): @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_sample_axis(self, time_axis_shape, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) @@ -657,7 +647,7 @@ def test_call_sample_axis(self, time_axis_shape, mode, kwargs, cls): ], ) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_sample_range(self, mn, mx, expectation, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) @@ -722,7 +712,7 @@ def test_compute_features_conv_input( order, width, cls, - class_specific_params, + basis_class_specific_params, ): x = np.ones(input_shape) @@ -737,7 +727,9 @@ def test_compute_features_conv_input( ) # figure out which kwargs needs to be removed - kwargs = inspect_utils.trim_kwargs(cls["conv"], kwargs, class_specific_params) + kwargs = inspect_utils.trim_kwargs( + cls["conv"], kwargs, basis_class_specific_params + ) basis_obj = cls["conv"](**kwargs) out = basis_obj.compute_features(x) @@ -913,7 +905,7 @@ def test_convolution_is_performed(self, cls): @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_evaluate_on_grid_basis_size(self, sample_size, mode, kwargs, cls): if "OrthExp" in cls["eval"].__name__: @@ -932,7 +924,7 @@ def test_evaluate_on_grid_basis_size(self, sample_size, mode, kwargs, cls): @pytest.mark.parametrize("n_input", [0, 1, 2]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_evaluate_on_grid_input_number(self, n_input, mode, kwargs, cls): basis_obj = cls[mode]( @@ -957,7 +949,7 @@ def test_evaluate_on_grid_input_number(self, n_input, mode, kwargs, cls): @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_evaluate_on_grid_meshgrid_size(self, sample_size, mode, kwargs, cls): if "OrthExp" in cls["eval"].__name__: @@ -991,7 +983,7 @@ def test_fit_kernel_shape(self, cls): @pytest.mark.parametrize( "mode, ws, expectation", [ - ("conv", 2, does_not_raise()), + ("conv", 5, does_not_raise()), ( "conv", -1, @@ -1036,9 +1028,9 @@ def test_init_window_size(self, mode, ws, expectation, cls): n_basis_funcs=5, window_size=ws, **extra_decay_rates(cls[mode], 5) ) - @pytest.mark.parametrize("samples", [[], [0], [0, 0]]) + @pytest.mark.parametrize("samples", [[], [0] * 10, [0] * 11]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_non_empty_samples(self, samples, mode, kwargs, cls): if "OrthExp" in cls["eval"].__name__: @@ -1083,7 +1075,7 @@ def test_number_of_required_inputs_compute_features( basis_obj.compute_features(*inputs) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_pynapple_support(self, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) @@ -1181,7 +1173,7 @@ def test_set_params( decay_rates, conv_kwargs, cls, - class_specific_params, + basis_class_specific_params, ): """Test the read-only and read/write property of the parameters.""" pars = dict( @@ -1198,7 +1190,7 @@ def test_set_params( pars = { key: value for key, value in pars.items() - if key in class_specific_params[cls[mode].__name__] + if key in basis_class_specific_params[cls[mode].__name__] } keys = list(pars.keys()) @@ -1242,15 +1234,16 @@ def test_set_window_size(self, mode, expectation, cls): def test_transform_fails(self, cls): bas = cls["conv"]( - n_basis_funcs=5, window_size=3, **extra_decay_rates(cls["conv"], 5) + n_basis_funcs=5, window_size=5, **extra_decay_rates(cls["conv"], 5) ) with pytest.raises( - ValueError, match="You must call `_set_kernel` before `_compute_features`" + ValueError, match="You must call `setup_basis` before `_compute_features`" ): bas._compute_features(np.linspace(0, 1, 10)) def test_transformer_get_params(self, cls): bas = cls["eval"](n_basis_funcs=5, **extra_decay_rates(cls["eval"], 5)) + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) bas_transformer = bas.to_transformer() params_transf = bas_transformer.get_params() params_transf.pop("_basis") @@ -1355,7 +1348,7 @@ def test_decay_to_zero_basis_number_match(self, width): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_minimum_number_of_basis_required_is_matched( self, n_basis_funcs, mode, kwargs @@ -1466,7 +1459,7 @@ def test_time_scaling_values(self, time_scaling, expectation, mode, kwargs): ], ) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_width_values(self, width, expectation, mode, kwargs): with expectation: @@ -1478,7 +1471,7 @@ class TestRaisedCosineLinearBasis(BasisFuncsTesting): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_minimum_number_of_basis_required_is_matched( self, n_basis_funcs, mode, kwargs @@ -1553,7 +1546,7 @@ class TestMSplineBasis(BasisFuncsTesting): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("order", [-1, 0, 1, 2, 3, 4, 5]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_minimum_number_of_basis_required_is_matched( self, n_basis_funcs, order, mode, kwargs @@ -1654,6 +1647,50 @@ def test_vmin_vmax_eval_on_grid_scaling_effect_on_eval( class TestOrthExponentialBasis(BasisFuncsTesting): cls = {"eval": basis.OrthExponentialEval, "conv": basis.OrthExponentialConv} + @pytest.mark.parametrize( + "window_size, n_basis, expectation", + [ + ( + 4, + 5, + pytest.raises( + ValueError, + match="OrthExponentialConv basis requires at least a window_size", + ), + ), + (5, 5, does_not_raise()), + ], + ) + def test_window_size_at_init(self, window_size, n_basis, expectation): + decay_rates = np.asarray(np.arange(1, n_basis + 1), dtype=float) + with expectation: + self.cls["conv"](n_basis, decay_rates=decay_rates, window_size=window_size) + + @pytest.mark.parametrize( + "window_size, n_basis, expectation", + [ + ( + 4, + 5, + pytest.raises( + ValueError, + match="OrthExponentialConv basis requires at least a window_size", + ), + ), + (5, 5, does_not_raise()), + ], + ) + def test_window_size_at_init(self, window_size, n_basis, expectation): + decay_rates = np.asarray(np.arange(1, n_basis + 1), dtype=float) + obj = self.cls["conv"]( + n_basis, decay_rates=decay_rates, window_size=n_basis + 1 + ) + with expectation: + obj.window_size = window_size + + with expectation: + obj.set_params(window_size=window_size) + @pytest.mark.parametrize( "decay_rates", [[1, 2, 3], [0.01, 0.02, 0.001], [2, 1, 1, 2.4]] ) @@ -1729,7 +1766,7 @@ class TestBSplineBasis(BasisFuncsTesting): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("order", [1, 2, 3, 4, 5]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_minimum_number_of_basis_required_is_matched( self, n_basis_funcs, order, mode, kwargs @@ -1817,7 +1854,7 @@ class TestCyclicBSplineBasis(BasisFuncsTesting): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("order", [2, 3, 4, 5]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_minimum_number_of_basis_required_is_matched( self, n_basis_funcs, order, mode, kwargs @@ -1918,72 +1955,16 @@ def test_samples_range_matches_compute_features_requirements( basis_obj.compute_features(np.linspace(*sample_range, 100)) -class CombinedBasis(BasisFuncsTesting): - """ - This class is used to run tests on combination operations (e.g., addition, multiplication) among Basis functions. - - Properties: - - cls: Class (default = None) - """ - - cls = None - - @staticmethod - def instantiate_basis( - n_basis, basis_class, class_specific_params, window_size=10, **kwargs - ): - """Instantiate and return two basis of the type specified.""" - - # Set non-optional args - default_kwargs = { - "n_basis_funcs": n_basis, - "window_size": window_size, - "decay_rates": np.arange(1, 1 + n_basis), - } - repeated_keys = set(default_kwargs.keys()).intersection(kwargs.keys()) - if repeated_keys: - raise ValueError( - "Cannot set `n_basis_funcs, window_size, decay_rates` with kwargs" - ) - - # Merge with provided extra kwargs - kwargs = {**default_kwargs, **kwargs} - - if basis_class == AdditiveBasis: - kwargs_mspline = inspect_utils.trim_kwargs( - basis.MSplineEval, kwargs, class_specific_params - ) - kwargs_raised_cosine = inspect_utils.trim_kwargs( - basis.RaisedCosineLinearConv, kwargs, class_specific_params - ) - b1 = basis.MSplineEval(**kwargs_mspline) - b2 = basis.RaisedCosineLinearConv(**kwargs_raised_cosine) - basis_obj = b1 + b2 - elif basis_class == MultiplicativeBasis: - kwargs_mspline = inspect_utils.trim_kwargs( - basis.MSplineEval, kwargs, class_specific_params - ) - kwargs_raised_cosine = inspect_utils.trim_kwargs( - basis.RaisedCosineLinearConv, kwargs, class_specific_params - ) - b1 = basis.MSplineEval(**kwargs_mspline) - b2 = basis.RaisedCosineLinearConv(**kwargs_raised_cosine) - basis_obj = b1 * b2 - else: - basis_obj = basis_class( - **inspect_utils.trim_kwargs(basis_class, kwargs, class_specific_params) - ) - return basis_obj - - class TestAdditiveBasis(CombinedBasis): cls = {"eval": AdditiveBasis, "conv": AdditiveBasis} @pytest.mark.parametrize("samples", [[[0], []], [[], [0]], [[0, 0], [0, 0]]]) @pytest.mark.parametrize("base_cls", [basis.BSplineEval, basis.BSplineConv]) - def test_non_empty_samples(self, base_cls, samples, class_specific_params): + def test_non_empty_samples(self, base_cls, samples, basis_class_specific_params): kwargs = {"window_size": 2, "n_basis_funcs": 5} - kwargs = inspect_utils.trim_kwargs(base_cls, kwargs, class_specific_params) + kwargs = inspect_utils.trim_kwargs( + base_cls, kwargs, basis_class_specific_params + ) basis_obj = base_cls(**kwargs) + base_cls(**kwargs) if any(tuple(len(s) == 0 for s in samples)): with pytest.raises( @@ -2010,6 +1991,68 @@ def test_compute_features_input(self, eval_input): basis_obj = basis.MSplineEval(5) + basis.MSplineEval(5) basis_obj.compute_features(*eval_input) + @pytest.mark.parametrize("n_basis_a", [6]) + @pytest.mark.parametrize("n_basis_b", [5]) + @pytest.mark.parametrize("vmin, vmax", [(-1, 1)]) + @pytest.mark.parametrize("basis_a", list_all_basis_classes()) + @pytest.mark.parametrize("basis_b", list_all_basis_classes()) + @pytest.mark.parametrize("inp_num", [1, 2]) + def test_sklearn_clone( + self, + basis_a, + basis_b, + n_basis_a, + n_basis_b, + vmin, + vmax, + inp_num, + basis_class_specific_params, + ): + """Recursively check cloning.""" + basis_a_obj = self.instantiate_basis( + n_basis_a, basis_a, basis_class_specific_params, window_size=10 + ) + basis_a_obj = basis_a_obj.set_input_shape( + *([inp_num] * basis_a_obj._n_input_dimensionality) + ) + basis_b_obj = self.instantiate_basis( + n_basis_b, basis_b, basis_class_specific_params, window_size=15 + ) + basis_b_obj = basis_b_obj.set_input_shape( + *([inp_num] * basis_b_obj._n_input_dimensionality) + ) + add = basis_a_obj + basis_b_obj + + def filter_attributes(obj, exclude_keys): + return { + key: val for key, val in obj.__dict__.items() if key not in exclude_keys + } + + def compare(b1, b2): + assert id(b1) != id(b2) + assert b1.__class__.__name__ == b2.__class__.__name__ + if hasattr(b1, "basis1"): + compare(b1.basis1, b2.basis1) + compare(b1.basis2, b2.basis2) + # add all params that are not parent or _basis1,_basis2 + d1 = filter_attributes( + b1, exclude_keys=["_basis1", "_basis2", "_parent"] + ) + d2 = filter_attributes( + b2, exclude_keys=["_basis1", "_basis2", "_parent"] + ) + assert d1 == d2 + else: + decay_rates_b1 = b1.__dict__.get("_decay_rates", -1) + decay_rates_b2 = b2.__dict__.get("_decay_rates", -1) + assert np.array_equal(decay_rates_b1, decay_rates_b2) + d1 = filter_attributes(b1, exclude_keys=["_decay_rates", "_parent"]) + d2 = filter_attributes(b2, exclude_keys=["_decay_rates", "_parent"]) + assert d1 == d2 + + add2 = add.__sklearn_clone__() + compare(add, add2) + @pytest.mark.parametrize("n_basis_a", [5, 6]) @pytest.mark.parametrize("n_basis_b", [5, 6]) @pytest.mark.parametrize("sample_size", [10, 1000]) @@ -2024,7 +2067,7 @@ def test_compute_features_returns_expected_number_of_basis( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the evaluation of the `AdditiveBasis` results in a number of basis @@ -2032,10 +2075,10 @@ def test_compute_features_returns_expected_number_of_basis( """ # define the two basis basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj @@ -2063,16 +2106,16 @@ def test_sample_size_of_compute_features_matches_that_of_input( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the output sample size from `AdditiveBasis` compute_features function matches input sample size. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj eval_basis = basis_obj.compute_features( @@ -2100,17 +2143,17 @@ def test_number_of_required_inputs_compute_features( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the number of required inputs for the `compute_features` function matches the sum of the number of input samples from the two bases. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj required_dim = ( @@ -2132,16 +2175,22 @@ def test_number_of_required_inputs_compute_features( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_meshgrid_size( - self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, + sample_size, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + basis_class_specific_params, ): """ Test whether the resulting meshgrid size matches the sample size input. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj + basis_b_obj res = basis_obj.evaluate_on_grid( @@ -2156,16 +2205,22 @@ def test_evaluate_on_grid_meshgrid_size( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_basis_size( - self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, + sample_size, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + basis_class_specific_params, ): """ Test whether the number sample size output by evaluate_on_grid matches the sample size of the input. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj + basis_b_obj eval_basis = basis_obj.evaluate_on_grid( @@ -2179,17 +2234,23 @@ def test_evaluate_on_grid_basis_size( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_input_number( - self, n_input, basis_a, basis_b, n_basis_a, n_basis_b, class_specific_params + self, + n_input, + basis_a, + basis_b, + n_basis_a, + n_basis_b, + basis_class_specific_params, ): """ Test whether the number of inputs provided to `evaluate_on_grid` matches the sum of the number of input samples required from each of the basis objects. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj + basis_b_obj inputs = [20] * n_input @@ -2211,7 +2272,13 @@ def test_evaluate_on_grid_input_number( @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) def test_pynapple_support_compute_features( - self, basis_a, basis_b, n_basis_a, n_basis_b, sample_size, class_specific_params + self, + basis_a, + basis_b, + n_basis_a, + n_basis_b, + sample_size, + basis_class_specific_params, ): iset = nap.IntervalSet(start=[0, 0.5], end=[0.49999, 1]) inp = nap.Tsd( @@ -2220,9 +2287,9 @@ def test_pynapple_support_compute_features( time_support=iset, ) basis_add = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) + self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) # compute_features the basis over pynapple Tsd objects out = basis_add.compute_features(*([inp] * basis_add._n_input_dimensionality)) @@ -2237,7 +2304,7 @@ def test_pynapple_support_compute_features( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) @pytest.mark.parametrize("num_input", [0, 1, 2, 3, 4, 5]) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) def test_call_input_num( self, n_basis_a, @@ -2246,13 +2313,13 @@ def test_call_input_num( basis_b, num_input, window_size, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj if num_input == basis_obj._n_input_dimensionality: @@ -2271,7 +2338,7 @@ def test_call_input_num( (np.linspace(0, 1, 10)[:, None], pytest.raises(ValueError)), ], ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2285,20 +2352,20 @@ def test_call_input_shape( inp, window_size, expectation, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj with expectation: basis_obj._evaluate(*([inp] * basis_obj._n_input_dimensionality)) @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2311,25 +2378,31 @@ def test_call_sample_axis( basis_b, time_axis_shape, window_size, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj inp = [np.linspace(0, 1, time_axis_shape)] * basis_obj._n_input_dimensionality assert basis_obj._evaluate(*inp).shape[0] == time_axis_shape - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_nan( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): if ( basis_a == basis.OrthExponentialBasis @@ -2337,10 +2410,10 @@ def test_call_nan( ): return basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj inp = [np.linspace(0, 1, 10)] * basis_obj._n_input_dimensionality @@ -2353,40 +2426,46 @@ def test_call_nan( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_equivalent_in_conv( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=3 + n_basis_a, basis_a, basis_class_specific_params, window_size=9 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=3 + n_basis_b, basis_b, basis_class_specific_params, window_size=9 ) bas_eva = basis_a_obj + basis_b_obj basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=8 + n_basis_a, basis_a, basis_class_specific_params, window_size=8 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=8 + n_basis_b, basis_b, basis_class_specific_params, window_size=8 ) bas_con = basis_a_obj + basis_b_obj x = [np.linspace(0, 1, 10)] * bas_con._n_input_dimensionality assert np.all(bas_con._evaluate(*x) == bas_eva._evaluate(*x)) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_pynapple_support( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj x = np.linspace(0, 1, 10) @@ -2398,19 +2477,25 @@ def test_pynapple_support( assert np.all(y == y_nap.d) assert np.all(y_nap.t == x_nap[0].t) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [6, 7]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_basis_number( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality @@ -2419,19 +2504,25 @@ def test_call_basis_number( == basis_a_obj.n_basis_funcs + basis_b_obj.n_basis_funcs ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_non_empty( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj with pytest.raises(ValueError, match="All sample provided must"): @@ -2445,7 +2536,7 @@ def test_call_non_empty( (0.1, 2, does_not_raise()), ], ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2460,7 +2551,7 @@ def test_call_sample_range( mx, expectation, window_size, - class_specific_params, + basis_class_specific_params, ): if expectation == "check": if ( @@ -2473,10 +2564,10 @@ def test_call_sample_range( else: expectation = does_not_raise() basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj with expectation: @@ -2487,16 +2578,16 @@ def test_call_sample_range( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_fit_kernel( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) bas = basis_a_obj + basis_b_obj - bas.set_kernel() + bas.setup_basis(*([np.ones(10)] * bas._n_input_dimensionality)) def check_kernel(basis_obj): has_kern = [] @@ -2516,13 +2607,13 @@ def check_kernel(basis_obj): @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_transform_fails( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) bas = basis_a_obj + basis_b_obj if "Eval" in basis_a.__name__ and "Eval" in basis_b.__name__: @@ -2530,7 +2621,7 @@ def test_transform_fails( else: context = pytest.raises( ValueError, - match="You must call `_set_kernel` before `_compute_features`", + match="You must call `setup_basis` before `_compute_features`", ) with context: x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality @@ -2554,11 +2645,11 @@ def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): bas1 = basis.RaisedCosineLinearConv(10, window_size=10) bas2 = basis.BSplineConv(10, window_size=10) bas_add = bas1 + bas2 - assert bas_add.n_basis_input is None + assert bas_add.n_basis_input_ is None bas_add.compute_features( np.ones((20, n_basis_input1)), np.ones((20, n_basis_input2)) ) - assert bas_add.n_basis_input == (n_basis_input1, n_basis_input2) + assert bas_add.n_basis_input_ == (n_basis_input1, n_basis_input2) @pytest.mark.parametrize( "n_input, expectation", @@ -2594,16 +2685,16 @@ def test_set_input_shape_type_1d_arrays( basis_b, shape_a, shape_b, - class_specific_params, + basis_class_specific_params, add_shape_a, add_shape_b, ): x = (np.ones((10, *add_shape_a)), np.ones((10, *add_shape_b))) basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) add = basis_a + basis_b @@ -2633,16 +2724,16 @@ def test_set_input_shape_type_2d_arrays( basis_b, shape_a, shape_b, - class_specific_params, + basis_class_specific_params, add_shape_a, add_shape_b, ): x = (np.ones((10, 2, *add_shape_a)), np.ones((10, 3, *add_shape_b))) basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) add = basis_a + basis_b @@ -2672,16 +2763,16 @@ def test_set_input_shape_type_nd_arrays( basis_b, shape_a, shape_b, - class_specific_params, + basis_class_specific_params, add_shape_a, add_shape_b, ): x = (np.ones((10, 2, 2, *add_shape_a)), np.ones((10, 3, 1, *add_shape_b))) basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) add = basis_a + basis_b @@ -2720,18 +2811,103 @@ def test_set_input_shape_type_nd_arrays( "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") ) def test_set_input_value_types( - self, inp_shape, expectation, basis_a, basis_b, class_specific_params + self, inp_shape, expectation, basis_a, basis_b, basis_class_specific_params ): basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) add = basis_a + basis_b with expectation: add.set_input_shape(*inp_shape) + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + def test_deep_copy_basis(self, basis_a, basis_b, basis_class_specific_params): + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + add = basis_a + basis_b + # test pointing to different objects + assert id(add.basis1) != id(basis_a) + assert id(add.basis1) != id(basis_b) + assert id(add.basis2) != id(basis_a) + assert id(add.basis2) != id(basis_b) + + # test attributes are not related + basis_a.n_basis_funcs = 10 + basis_b.n_basis_funcs = 10 + assert add.basis1.n_basis_funcs == 5 + assert add.basis2.n_basis_funcs == 5 + + add.basis1.n_basis_funcs = 6 + add.basis2.n_basis_funcs = 6 + assert basis_a.n_basis_funcs == 10 + assert basis_b.n_basis_funcs == 10 + + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + def test_compute_n_basis_runtime( + self, basis_a, basis_b, basis_class_specific_params + ): + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + add = basis_a + basis_b + add.basis1.n_basis_funcs = 10 + assert add.n_basis_funcs == 15 + add.basis2.n_basis_funcs = 10 + assert add.n_basis_funcs == 20 + + @pytest.mark.parametrize("basis_a", list_all_basis_classes()) + @pytest.mark.parametrize("basis_b", list_all_basis_classes()) + def test_runtime_n_basis_out_compute( + self, basis_a, basis_b, basis_class_specific_params + ): + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_a.set_input_shape( + *([1] * basis_a._n_input_dimensionality) + ).to_transformer() + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + basis_b.set_input_shape( + *([1] * basis_b._n_input_dimensionality) + ).to_transformer() + add = basis_a + basis_b + inps_a = [2] * basis_a._n_input_dimensionality + add.basis1.set_input_shape(*inps_a) + if isinstance(basis_a, MultiplicativeBasis): + new_out_num = np.prod(inps_a) * add.basis1.n_basis_funcs + else: + new_out_num = inps_a[0] * add.basis1.n_basis_funcs + assert add.n_output_features == new_out_num + add.basis2.n_basis_funcs + inps_b = [3] * basis_b._n_input_dimensionality + if isinstance(basis_b, MultiplicativeBasis): + new_out_num_b = np.prod(inps_b) * add.basis2.n_basis_funcs + else: + new_out_num_b = inps_b[0] * add.basis2.n_basis_funcs + add.basis2.set_input_shape(*inps_b) + assert add.n_output_features == new_out_num + new_out_num_b + class TestMultiplicativeBasis(CombinedBasis): cls = {"eval": MultiplicativeBasis, "conv": MultiplicativeBasis} @@ -2781,7 +2957,7 @@ def test_compute_features_returns_expected_number_of_basis( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the evaluation of the `MultiplicativeBasis` results in a number of basis @@ -2789,10 +2965,10 @@ def test_compute_features_returns_expected_number_of_basis( """ # define the two basis basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj @@ -2821,17 +2997,17 @@ def test_sample_size_of_compute_features_matches_that_of_input( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the output sample size from the `MultiplicativeBasis` fit_transform function matches the input sample size. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj eval_basis = basis_obj.compute_features( @@ -2858,17 +3034,17 @@ def test_number_of_required_inputs_compute_features( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the number of required inputs for the `compute_features` function matches the sum of the number of input samples from the two bases. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj required_dim = ( @@ -2890,16 +3066,22 @@ def test_number_of_required_inputs_compute_features( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_meshgrid_size( - self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, + sample_size, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + basis_class_specific_params, ): """ Test whether the resulting meshgrid size matches the sample size input. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj * basis_b_obj res = basis_obj.evaluate_on_grid( @@ -2914,16 +3096,22 @@ def test_evaluate_on_grid_meshgrid_size( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_basis_size( - self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, + sample_size, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + basis_class_specific_params, ): """ Test whether the number sample size output by evaluate_on_grid matches the sample size of the input. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj * basis_b_obj eval_basis = basis_obj.evaluate_on_grid( @@ -2937,17 +3125,23 @@ def test_evaluate_on_grid_basis_size( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_input_number( - self, n_input, basis_a, basis_b, n_basis_a, n_basis_b, class_specific_params + self, + n_input, + basis_a, + basis_b, + n_basis_a, + n_basis_b, + basis_class_specific_params, ): """ Test whether the number of inputs provided to `evaluate_on_grid` matches the sum of the number of input samples required from each of the basis objects. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj * basis_b_obj inputs = [20] * n_input @@ -2977,15 +3171,15 @@ def test_inconsistent_sample_sizes( n_basis_b, sample_size_a, sample_size_b, - class_specific_params, + basis_class_specific_params, ): """Test that the inputs of inconsistent sample sizes result in an exception when compute_features is called""" raise_exception = sample_size_a != sample_size_b basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) input_a = [ np.linspace(0, 1, sample_size_a) @@ -3009,7 +3203,13 @@ def test_inconsistent_sample_sizes( @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) def test_pynapple_support_compute_features( - self, basis_a, basis_b, n_basis_a, n_basis_b, sample_size, class_specific_params + self, + basis_a, + basis_b, + n_basis_a, + n_basis_b, + sample_size, + basis_class_specific_params, ): iset = nap.IntervalSet(start=[0, 0.5], end=[0.49999, 1]) inp = nap.Tsd( @@ -3018,9 +3218,9 @@ def test_pynapple_support_compute_features( time_support=iset, ) basis_prod = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) * self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) out = basis_prod.compute_features(*([inp] * basis_prod._n_input_dimensionality)) assert isinstance(out, nap.TsdFrame) @@ -3032,7 +3232,7 @@ def test_pynapple_support_compute_features( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) @pytest.mark.parametrize("num_input", [0, 1, 2, 3, 4, 5]) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) def test_call_input_num( self, n_basis_a, @@ -3041,13 +3241,13 @@ def test_call_input_num( basis_b, num_input, window_size, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj if num_input == basis_obj._n_input_dimensionality: @@ -3066,7 +3266,7 @@ def test_call_input_num( (np.linspace(0, 1, 10)[:, None], pytest.raises(ValueError)), ], ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -3080,20 +3280,20 @@ def test_call_input_shape( inp, window_size, expectation, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj with expectation: basis_obj._evaluate(*([inp] * basis_obj._n_input_dimensionality)) @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -3106,25 +3306,31 @@ def test_call_sample_axis( basis_b, time_axis_shape, window_size, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj inp = [np.linspace(0, 1, time_axis_shape)] * basis_obj._n_input_dimensionality assert basis_obj._evaluate(*inp).shape[0] == time_axis_shape - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_nan( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): if ( basis_a == basis.OrthExponentialBasis @@ -3132,10 +3338,10 @@ def test_call_nan( ): return basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj inp = [np.linspace(0, 1, 10)] * basis_obj._n_input_dimensionality @@ -3148,40 +3354,46 @@ def test_call_nan( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_equivalent_in_conv( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) bas_eva = basis_a_obj * basis_b_obj basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=8 + n_basis_a, basis_a, basis_class_specific_params, window_size=8 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=8 + n_basis_b, basis_b, basis_class_specific_params, window_size=8 ) bas_con = basis_a_obj * basis_b_obj x = [np.linspace(0, 1, 10)] * bas_con._n_input_dimensionality assert np.all(bas_con._evaluate(*x) == bas_eva._evaluate(*x)) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_pynapple_support( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj x = np.linspace(0, 1, 10) @@ -3193,19 +3405,25 @@ def test_pynapple_support( assert np.all(y == y_nap.d) assert np.all(y_nap.t == x_nap[0].t) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [6, 7]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_basis_number( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality @@ -3214,19 +3432,25 @@ def test_call_basis_number( == basis_a_obj.n_basis_funcs * basis_b_obj.n_basis_funcs ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_non_empty( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj with pytest.raises(ValueError, match="All sample provided must"): @@ -3240,7 +3464,7 @@ def test_call_non_empty( (0.1, 2, does_not_raise()), ], ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -3255,7 +3479,7 @@ def test_call_sample_range( mx, expectation, window_size, - class_specific_params, + basis_class_specific_params, ): if expectation == "check": if ( @@ -3268,10 +3492,10 @@ def test_call_sample_range( else: expectation = does_not_raise() basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj with expectation: @@ -3282,16 +3506,16 @@ def test_call_sample_range( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_fit_kernel( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) bas = basis_a_obj * basis_b_obj - bas.set_kernel() + bas._set_input_independent_states() def check_kernel(basis_obj): has_kern = [] @@ -3311,13 +3535,13 @@ def check_kernel(basis_obj): @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_transform_fails( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) bas = basis_a_obj * basis_b_obj if "Eval" in basis_a.__name__ and "Eval" in basis_b.__name__: @@ -3325,7 +3549,7 @@ def test_transform_fails( else: context = pytest.raises( ValueError, - match="You must call `_set_kernel` before `_compute_features`", + match="You must call `setup_basis` before `_compute_features`", ) with context: x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality @@ -3349,11 +3573,11 @@ def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): bas1 = basis.RaisedCosineLinearConv(10, window_size=10) bas2 = basis.BSplineConv(10, window_size=10) bas_add = bas1 * bas2 - assert bas_add.n_basis_input is None + assert bas_add.n_basis_input_ is None bas_add.compute_features( np.ones((20, n_basis_input1)), np.ones((20, n_basis_input2)) ) - assert bas_add.n_basis_input == (n_basis_input1, n_basis_input2) + assert bas_add.n_basis_input_ == (n_basis_input1, n_basis_input2) @pytest.mark.parametrize( "n_input, expectation", @@ -3375,14 +3599,14 @@ def test_expected_input_number(self, n_input, expectation): @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) - def test_n_basis_input(self, n_basis_input1, n_basis_input2): + def test_n_basis_input_(self, n_basis_input1, n_basis_input2): bas1 = basis.RaisedCosineLinearConv(10, window_size=10) bas2 = basis.BSplineConv(10, window_size=10) bas_prod = bas1 * bas2 bas_prod.compute_features( np.ones((20, n_basis_input1)), np.ones((20, n_basis_input2)) ) - assert bas_prod.n_basis_input == (n_basis_input1, n_basis_input2) + assert bas_prod.n_basis_input_ == (n_basis_input1, n_basis_input2) @pytest.mark.parametrize( "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") @@ -3400,16 +3624,16 @@ def test_set_input_shape_type_1d_arrays( basis_b, shape_a, shape_b, - class_specific_params, + basis_class_specific_params, add_shape_a, add_shape_b, ): x = (np.ones((10, *add_shape_a)), np.ones((10, *add_shape_b))) basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) mul = basis_a * basis_b @@ -3439,16 +3663,16 @@ def test_set_input_shape_type_2d_arrays( basis_b, shape_a, shape_b, - class_specific_params, + basis_class_specific_params, add_shape_a, add_shape_b, ): x = (np.ones((10, 2, *add_shape_a)), np.ones((10, 3, *add_shape_b))) basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) mul = basis_a * basis_b @@ -3478,16 +3702,16 @@ def test_set_input_shape_type_nd_arrays( basis_b, shape_a, shape_b, - class_specific_params, + basis_class_specific_params, add_shape_a, add_shape_b, ): x = (np.ones((10, 2, 2, *add_shape_a)), np.ones((10, 3, 1, *add_shape_b))) basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) mul = basis_a * basis_b @@ -3526,24 +3750,109 @@ def test_set_input_shape_type_nd_arrays( "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") ) def test_set_input_value_types( - self, inp_shape, expectation, basis_a, basis_b, class_specific_params + self, inp_shape, expectation, basis_a, basis_b, basis_class_specific_params ): basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) mul = basis_a * basis_b with expectation: mul.set_input_shape(*inp_shape) + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + def test_deep_copy_basis(self, basis_a, basis_b, basis_class_specific_params): + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + mul = basis_a * basis_b + # test pointing to different objects + assert id(mul.basis1) != id(basis_a) + assert id(mul.basis1) != id(basis_b) + assert id(mul.basis2) != id(basis_a) + assert id(mul.basis2) != id(basis_b) + + # test attributes are not related + basis_a.n_basis_funcs = 10 + basis_b.n_basis_funcs = 10 + assert mul.basis1.n_basis_funcs == 5 + assert mul.basis2.n_basis_funcs == 5 + + mul.basis1.n_basis_funcs = 6 + mul.basis2.n_basis_funcs = 6 + assert basis_a.n_basis_funcs == 10 + assert basis_b.n_basis_funcs == 10 + + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + def test_compute_n_basis_runtime( + self, basis_a, basis_b, basis_class_specific_params + ): + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + mul = basis_a * basis_b + mul.basis1.n_basis_funcs = 10 + assert mul.n_basis_funcs == 50 + mul.basis2.n_basis_funcs = 10 + assert mul.n_basis_funcs == 100 + + @pytest.mark.parametrize("basis_a", list_all_basis_classes()) + @pytest.mark.parametrize("basis_b", list_all_basis_classes()) + def test_runtime_n_basis_out_compute( + self, basis_a, basis_b, basis_class_specific_params + ): + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_a.set_input_shape( + *([1] * basis_a._n_input_dimensionality) + ).to_transformer() + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + basis_b.set_input_shape( + *([1] * basis_b._n_input_dimensionality) + ).to_transformer() + mul = basis_a * basis_b + inps_a = [2] * basis_a._n_input_dimensionality + mul.basis1.set_input_shape(*inps_a) + if isinstance(basis_a, MultiplicativeBasis): + new_out_num = np.prod(inps_a) * mul.basis1.n_basis_funcs + else: + new_out_num = inps_a[0] * mul.basis1.n_basis_funcs + assert mul.n_output_features == new_out_num * mul.basis2.n_basis_funcs + inps_b = [3] * basis_b._n_input_dimensionality + if isinstance(basis_b, MultiplicativeBasis): + new_out_num_b = np.prod(inps_b) * mul.basis2.n_basis_funcs + else: + new_out_num_b = inps_b[0] * mul.basis2.n_basis_funcs + mul.basis2.set_input_shape(*inps_b) + assert mul.n_output_features == new_out_num * new_out_num_b + @pytest.mark.parametrize( "exponent", [-1, 0, 0.5, basis.RaisedCosineLogEval(4), 1, 2, 3] ) @pytest.mark.parametrize("basis_class", list_all_basis_classes()) -def test_power_of_basis(exponent, basis_class, class_specific_params): +def test_power_of_basis(exponent, basis_class, basis_class_specific_params): """Test if the power behaves as expected.""" raise_exception_type = not type(exponent) is int @@ -3553,7 +3862,7 @@ def test_power_of_basis(exponent, basis_class, class_specific_params): raise_exception_value = False basis_obj = CombinedBasis.instantiate_basis( - 5, basis_class, class_specific_params, window_size=10 + 5, basis_class, basis_class_specific_params, window_size=10 ) if raise_exception_type: @@ -3589,13 +3898,14 @@ def test_power_of_basis(exponent, basis_class, class_specific_params): "basis_cls", list_all_basis_classes(), ) -def test_basis_to_transformer(basis_cls, class_specific_params): +def test_basis_to_transformer(basis_cls, basis_class_specific_params): n_basis_funcs = 5 bas = CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=10 + n_basis_funcs, basis_cls, basis_class_specific_params, window_size=10 ) - - trans_bas = bas.to_transformer() + trans_bas = bas.set_input_shape( + *([1] * bas._n_input_dimensionality) + ).to_transformer() assert isinstance(trans_bas, basis.TransformerBasis) @@ -3607,386 +3917,6 @@ def test_basis_to_transformer(basis_cls, class_specific_params): assert np.all(getattr(bas, k) == getattr(trans_bas, k)) -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_transformer_has_the_same_public_attributes_as_basis( - basis_cls, class_specific_params -): - n_basis_funcs = 5 - bas = CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=10 - ) - - public_attrs_basis = {attr for attr in dir(bas) if not attr.startswith("_")} - public_attrs_transformerbasis = { - attr for attr in dir(bas.to_transformer()) if not attr.startswith("_") - } - - assert public_attrs_transformerbasis - public_attrs_basis == { - "fit", - "fit_transform", - "transform", - } - - assert public_attrs_basis - public_attrs_transformerbasis == set() - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), -) -def test_to_transformer_and_constructor_are_equivalent( - basis_cls, class_specific_params -): - n_basis_funcs = 5 - bas = CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=10 - ) - - trans_bas_a = bas.to_transformer() - trans_bas_b = basis.TransformerBasis(bas) - - # they both just have a _basis - assert ( - list(trans_bas_a.__dict__.keys()) - == list(trans_bas_b.__dict__.keys()) - == ["_basis"] - ) - # and those bases are the same - assert np.all( - trans_bas_a._basis.__dict__.pop("_decay_rates", 1) - == trans_bas_b._basis.__dict__.pop("_decay_rates", 1) - ) - assert trans_bas_a._basis.__dict__ == trans_bas_b._basis.__dict__ - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_basis_to_transformer_makes_a_copy(basis_cls, class_specific_params): - bas_a = CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 - ) - trans_bas_a = bas_a.to_transformer() - - # changing an attribute in bas should not change trans_bas - if basis_cls in [AdditiveBasis, MultiplicativeBasis]: - bas_a._basis1.n_basis_funcs = 10 - assert trans_bas_a._basis._basis1.n_basis_funcs == 5 - - # changing an attribute in the transformer basis should not change the original - bas_b = CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 - ) - trans_bas_b = bas_b.to_transformer() - trans_bas_b._basis._basis1.n_basis_funcs = 100 - assert bas_b._basis1.n_basis_funcs == 5 - else: - bas_a.n_basis_funcs = 10 - assert trans_bas_a.n_basis_funcs == 5 - - # changing an attribute in the transformer basis should not change the original - bas_b = CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 - ) - trans_bas_b = bas_b.to_transformer() - trans_bas_b.n_basis_funcs = 100 - assert bas_b.n_basis_funcs == 5 - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -@pytest.mark.parametrize("n_basis_funcs", [5, 10, 20]) -def test_transformerbasis_getattr(basis_cls, n_basis_funcs, class_specific_params): - trans_basis = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=10 - ) - ) - if basis_cls in [AdditiveBasis, MultiplicativeBasis]: - for bas in [ - getattr(trans_basis._basis, attr) for attr in ("_basis1", "_basis2") - ]: - assert bas.n_basis_funcs == n_basis_funcs - else: - assert trans_basis.n_basis_funcs == n_basis_funcs - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), -) -@pytest.mark.parametrize("n_basis_funcs_init", [5]) -@pytest.mark.parametrize("n_basis_funcs_new", [6, 10, 20]) -def test_transformerbasis_set_params( - basis_cls, n_basis_funcs_init, n_basis_funcs_new, class_specific_params -): - trans_basis = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - n_basis_funcs_init, basis_cls, class_specific_params, window_size=10 - ) - ) - trans_basis.set_params(n_basis_funcs=n_basis_funcs_new) - - assert trans_basis.n_basis_funcs == n_basis_funcs_new - assert trans_basis._basis.n_basis_funcs == n_basis_funcs_new - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), -) -def test_transformerbasis_setattr_basis(basis_cls, class_specific_params): - # setting the _basis attribute should change it - trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=10 - ) - ) - trans_bas._basis = CombinedBasis().instantiate_basis( - 20, basis_cls, class_specific_params, window_size=10 - ) - - assert trans_bas.n_basis_funcs == 20 - assert trans_bas._basis.n_basis_funcs == 20 - assert isinstance(trans_bas._basis, basis_cls) - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), -) -def test_transformerbasis_setattr_basis_attribute(basis_cls, class_specific_params): - # setting an attribute that is an attribute of the underlying _basis - # should propagate setting it on _basis itself - trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=10 - ) - ) - trans_bas.n_basis_funcs = 20 - - assert trans_bas.n_basis_funcs == 20 - assert trans_bas._basis.n_basis_funcs == 20 - assert isinstance(trans_bas._basis, basis_cls) - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), -) -def test_transformerbasis_copy_basis_on_contsruct(basis_cls, class_specific_params): - # modifying the transformerbasis's attributes shouldn't - # touch the original basis that was used to create it - orig_bas = CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=10 - ) - trans_bas = basis.TransformerBasis(orig_bas) - trans_bas.n_basis_funcs = 20 - - assert orig_bas.n_basis_funcs == 10 - assert trans_bas._basis.n_basis_funcs == 20 - assert trans_bas._basis.n_basis_funcs == 20 - assert isinstance(trans_bas._basis, basis_cls) - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_transformerbasis_setattr_illegal_attribute(basis_cls, class_specific_params): - # changing an attribute that is not _basis or an attribute of _basis - # is not allowed - trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=10 - ) - ) - - with pytest.raises( - ValueError, - match="Only setting _basis or existing attributes of _basis is allowed.", - ): - trans_bas.random_attr = "random value" - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_transformerbasis_addition(basis_cls, class_specific_params): - n_basis_funcs_a = 5 - n_basis_funcs_b = n_basis_funcs_a * 2 - bas_a = CombinedBasis().instantiate_basis( - n_basis_funcs_a, basis_cls, class_specific_params, window_size=10 - ) - bas_b = CombinedBasis().instantiate_basis( - n_basis_funcs_b, basis_cls, class_specific_params, window_size=10 - ) - trans_bas_a = basis.TransformerBasis(bas_a) - trans_bas_b = basis.TransformerBasis(bas_b) - trans_bas_sum = trans_bas_a + trans_bas_b - assert isinstance(trans_bas_sum, basis.TransformerBasis) - assert isinstance(trans_bas_sum._basis, AdditiveBasis) - assert ( - trans_bas_sum.n_basis_funcs - == trans_bas_a.n_basis_funcs + trans_bas_b.n_basis_funcs - ) - assert ( - trans_bas_sum._n_input_dimensionality - == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality - ) - if basis_cls not in [AdditiveBasis, MultiplicativeBasis]: - assert trans_bas_sum._basis1.n_basis_funcs == n_basis_funcs_a - assert trans_bas_sum._basis2.n_basis_funcs == n_basis_funcs_b - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_transformerbasis_multiplication(basis_cls, class_specific_params): - n_basis_funcs_a = 5 - n_basis_funcs_b = n_basis_funcs_a * 2 - trans_bas_a = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - n_basis_funcs_a, basis_cls, class_specific_params, window_size=10 - ) - ) - trans_bas_b = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - n_basis_funcs_b, basis_cls, class_specific_params, window_size=10 - ) - ) - trans_bas_prod = trans_bas_a * trans_bas_b - assert isinstance(trans_bas_prod, basis.TransformerBasis) - assert isinstance(trans_bas_prod._basis, MultiplicativeBasis) - assert ( - trans_bas_prod.n_basis_funcs - == trans_bas_a.n_basis_funcs * trans_bas_b.n_basis_funcs - ) - assert ( - trans_bas_prod._n_input_dimensionality - == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality - ) - if basis_cls not in [AdditiveBasis, MultiplicativeBasis]: - assert trans_bas_prod._basis1.n_basis_funcs == n_basis_funcs_a - assert trans_bas_prod._basis2.n_basis_funcs == n_basis_funcs_b - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -@pytest.mark.parametrize( - "exponent, error_type, error_message", - [ - (2, does_not_raise, None), - (5, does_not_raise, None), - (0.5, TypeError, "Exponent should be an integer"), - (-1, ValueError, "Exponent should be a non-negative integer"), - ], -) -def test_transformerbasis_exponentiation( - basis_cls, exponent: int, error_type, error_message, class_specific_params -): - trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 - ) - ) - - if not isinstance(exponent, int): - with pytest.raises(error_type, match=error_message): - trans_bas_exp = trans_bas**exponent - assert isinstance(trans_bas_exp, basis.TransformerBasis) - assert isinstance(trans_bas_exp._basis, MultiplicativeBasis) - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_transformerbasis_dir(basis_cls, class_specific_params): - trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 - ) - ) - for attr_name in ( - "fit", - "transform", - "fit_transform", - "n_basis_funcs", - "mode", - "window_size", - ): - if ( - attr_name == "window_size" - and "Conv" not in trans_bas._basis.__class__.__name__ - ): - continue - assert attr_name in dir(trans_bas) - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv"), -) -def test_transformerbasis_sk_clone_kernel_noned(basis_cls, class_specific_params): - orig_bas = CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=20 - ) - trans_bas = basis.TransformerBasis(orig_bas) - - # kernel should be saved in the object after fit - trans_bas.fit(np.random.randn(100, 20)) - assert isinstance(trans_bas.kernel_, np.ndarray) - - # cloning should set kernel_ to None - trans_bas_clone = sk_clone(trans_bas) - - # the original object should still have kernel_ - assert isinstance(trans_bas.kernel_, np.ndarray) - # but the clone should not have one - assert trans_bas_clone.kernel_ is None - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -@pytest.mark.parametrize("n_basis_funcs", [5]) -def test_transformerbasis_pickle( - tmpdir, basis_cls, n_basis_funcs, class_specific_params -): - # the test that tries cross-validation with n_jobs = 2 already should test this - trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=10 - ) - ) - filepath = tmpdir / "transformerbasis.pickle" - with open(filepath, "wb") as f: - pickle.dump(trans_bas, f) - with open(filepath, "rb") as f: - trans_bas2 = pickle.load(f) - - assert isinstance(trans_bas2, basis.TransformerBasis) - if basis_cls in [AdditiveBasis, MultiplicativeBasis]: - for bas in [ - getattr(trans_bas2._basis, attr) for attr in ("_basis1", "_basis2") - ]: - assert bas.n_basis_funcs == n_basis_funcs - else: - assert trans_bas2.n_basis_funcs == n_basis_funcs - - @pytest.mark.parametrize( "tsd", [ @@ -4025,7 +3955,7 @@ def test_multi_epoch_pynapple_basis( shift, predictor_causality, nan_index, - class_specific_params, + basis_class_specific_params, ): """Test nan location in multi-epoch pynapple tsd.""" kwargs = dict( @@ -4039,7 +3969,11 @@ def test_multi_epoch_pynapple_basis( else: nbasis = 5 bas = CombinedBasis().instantiate_basis( - nbasis, basis_cls, class_specific_params, window_size=window_size, **kwargs + nbasis, + basis_cls, + basis_class_specific_params, + window_size=window_size, + **kwargs, ) n_input = bas._n_input_dimensionality @@ -4092,7 +4026,7 @@ def test_multi_epoch_pynapple_basis_transformer( shift, predictor_causality, nan_index, - class_specific_params, + basis_class_specific_params, ): """Test nan location in multi-epoch pynapple tsd.""" kwargs = dict( @@ -4106,18 +4040,22 @@ def test_multi_epoch_pynapple_basis_transformer( nbasis = 5 bas = CombinedBasis().instantiate_basis( - nbasis, basis_cls, class_specific_params, window_size=window_size, **kwargs + nbasis, + basis_cls, + basis_class_specific_params, + window_size=window_size, + **kwargs, ) n_input = bas._n_input_dimensionality - # pass through transformer - bas = basis.TransformerBasis(bas) - # concat input X = pynapple_concatenate_numpy([tsd[:, None]] * n_input, axis=1) # run convolutions + # pass through transformer + bas.set_input_shape(X) + bas = basis.TransformerBasis(bas) res = bas.fit_transform(X) # check nans @@ -4140,18 +4078,18 @@ def test_multi_epoch_pynapple_basis_transformer( "__add__", "__add__", lambda bas1, bas2, bas3: { - "1": slice(0, bas1._n_basis_input[0] * bas1.n_basis_funcs), + "1": slice(0, bas1._n_basis_input_[0] * bas1.n_basis_funcs), "2": slice( - bas1._n_basis_input[0] * bas1.n_basis_funcs, - bas1._n_basis_input[0] * bas1.n_basis_funcs - + bas2._n_basis_input[0] * bas2.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs + + bas2._n_basis_input_[0] * bas2.n_basis_funcs, ), "3": slice( - bas1._n_basis_input[0] * bas1.n_basis_funcs - + bas2._n_basis_input[0] * bas2.n_basis_funcs, - bas1._n_basis_input[0] * bas1.n_basis_funcs - + bas2._n_basis_input[0] * bas2.n_basis_funcs - + bas3._n_basis_input[0] * bas3.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs + + bas2._n_basis_input_[0] * bas2.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs + + bas2._n_basis_input_[0] * bas2.n_basis_funcs + + bas3._n_basis_input_[0] * bas3.n_basis_funcs, ), }, ), @@ -4159,13 +4097,13 @@ def test_multi_epoch_pynapple_basis_transformer( "__add__", "__mul__", lambda bas1, bas2, bas3: { - "1": slice(0, bas1._n_basis_input[0] * bas1.n_basis_funcs), + "1": slice(0, bas1._n_basis_input_[0] * bas1.n_basis_funcs), "(2 * 3)": slice( - bas1._n_basis_input[0] * bas1.n_basis_funcs, - bas1._n_basis_input[0] * bas1.n_basis_funcs - + bas2._n_basis_input[0] + bas1._n_basis_input_[0] * bas1.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs + + bas2._n_basis_input_[0] * bas2.n_basis_funcs - * bas3._n_basis_input[0] + * bas3._n_basis_input_[0] * bas3.n_basis_funcs, ), }, @@ -4177,11 +4115,11 @@ def test_multi_epoch_pynapple_basis_transformer( # note that it doesn't respect algebra order but execute right to left (first add then multiplies) "(1 * (2 + 3))": slice( 0, - bas1._n_basis_input[0] + bas1._n_basis_input_[0] * bas1.n_basis_funcs * ( - bas2._n_basis_input[0] * bas2.n_basis_funcs - + bas3._n_basis_input[0] * bas3.n_basis_funcs + bas2._n_basis_input_[0] * bas2.n_basis_funcs + + bas3._n_basis_input_[0] * bas3.n_basis_funcs ), ), }, @@ -4192,11 +4130,11 @@ def test_multi_epoch_pynapple_basis_transformer( lambda bas1, bas2, bas3: { "(1 * (2 * 3))": slice( 0, - bas1._n_basis_input[0] + bas1._n_basis_input_[0] * bas1.n_basis_funcs - * bas2._n_basis_input[0] + * bas2._n_basis_input_[0] * bas2.n_basis_funcs - * bas3._n_basis_input[0] + * bas3._n_basis_input_[0] * bas3.n_basis_funcs, ), }, @@ -4204,7 +4142,7 @@ def test_multi_epoch_pynapple_basis_transformer( ], ) def test__get_splitter( - bas1, bas2, bas3, operator1, operator2, compute_slice, class_specific_params + bas1, bas2, bas3, operator1, operator2, compute_slice, basis_class_specific_params ): # skip nested if any( @@ -4218,13 +4156,22 @@ def test__get_splitter( combine_basis = CombinedBasis() bas1_instance = combine_basis.instantiate_basis( - n_basis[0], bas1, class_specific_params, window_size=10, label="1" + n_basis[0], bas1, basis_class_specific_params, window_size=10, label="1" + ) + bas1_instance.set_input_shape( + *([n_input_basis[0]] * bas1_instance._n_input_dimensionality) ) bas2_instance = combine_basis.instantiate_basis( - n_basis[1], bas2, class_specific_params, window_size=10, label="2" + n_basis[1], bas2, basis_class_specific_params, window_size=10, label="2" + ) + bas2_instance.set_input_shape( + *([n_input_basis[1]] * bas2_instance._n_input_dimensionality) ) bas3_instance = combine_basis.instantiate_basis( - n_basis[2], bas3, class_specific_params, window_size=10, label="3" + n_basis[2], bas3, basis_class_specific_params, window_size=10, label="3" + ) + bas3_instance.set_input_shape( + *([n_input_basis[2]] * bas3_instance._n_input_dimensionality) ) func1 = getattr(bas1_instance, operator1) @@ -4250,11 +4197,11 @@ def test__get_splitter( 1, 1, lambda bas1, bas2: { - "1": slice(0, bas1._n_basis_input[0] * bas1.n_basis_funcs), + "1": slice(0, bas1._n_basis_input_[0] * bas1.n_basis_funcs), "2": slice( - bas1._n_basis_input[0] * bas1.n_basis_funcs, - bas1._n_basis_input[0] * bas1.n_basis_funcs - + bas2._n_basis_input[0] * bas2.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs + + bas2._n_basis_input_[0] * bas2.n_basis_funcs, ), }, ), @@ -4265,9 +4212,9 @@ def test__get_splitter( lambda bas1, bas2: { "(1 * 2)": slice( 0, - bas1._n_basis_input[0] + bas1._n_basis_input_[0] * bas1.n_basis_funcs - * bas2._n_basis_input[0] + * bas2._n_basis_input_[0] * bas2.n_basis_funcs, ) }, @@ -4292,7 +4239,7 @@ def test__get_splitter( 1, lambda bas1, bas2: { "(1 * 2)": slice( - 0, bas1._n_basis_input[0] * bas1.n_basis_funcs * bas2.n_basis_funcs + 0, bas1._n_basis_input_[0] * bas1.n_basis_funcs * bas2.n_basis_funcs ) }, ), @@ -4319,7 +4266,7 @@ def test__get_splitter( 2, lambda bas1, bas2: { "(1 * 2)": slice( - 0, bas2._n_basis_input[0] * bas1.n_basis_funcs * bas2.n_basis_funcs + 0, bas2._n_basis_input_[0] * bas1.n_basis_funcs * bas2.n_basis_funcs ) }, ), @@ -4361,7 +4308,7 @@ def test__get_splitter_split_by_input( n_input_basis_1, n_input_basis_2, compute_slice, - class_specific_params, + basis_class_specific_params, ): # skip nested if any( @@ -4373,10 +4320,17 @@ def test__get_splitter_split_by_input( n_basis = [5, 6] combine_basis = CombinedBasis() bas1_instance = combine_basis.instantiate_basis( - n_basis[0], bas1, class_specific_params, window_size=10, label="1" + n_basis[0], bas1, basis_class_specific_params, window_size=10, label="1" ) + bas1_instance.set_input_shape( + *([n_input_basis_1] * bas1_instance._n_input_dimensionality) + ) + bas2_instance = combine_basis.instantiate_basis( - n_basis[1], bas2, class_specific_params, window_size=10, label="2" + n_basis[1], bas2, basis_class_specific_params, window_size=10, label="2" + ) + bas2_instance.set_input_shape( + *([n_input_basis_2] * bas1_instance._n_input_dimensionality) ) func1 = getattr(bas1_instance, operator) @@ -4396,7 +4350,7 @@ def test__get_splitter_split_by_input( "bas1, bas2, bas3", list(itertools.product(*[list_all_basis_classes()] * 3)), ) -def test_duplicate_keys(bas1, bas2, bas3, class_specific_params): +def test_duplicate_keys(bas1, bas2, bas3, basis_class_specific_params): # skip nested if any( bas in (AdditiveBasis, MultiplicativeBasis, basis.TransformerBasis) @@ -4406,13 +4360,13 @@ def test_duplicate_keys(bas1, bas2, bas3, class_specific_params): combine_basis = CombinedBasis() bas1_instance = combine_basis.instantiate_basis( - 5, bas1, class_specific_params, window_size=10, label="label" + 5, bas1, basis_class_specific_params, window_size=10, label="label" ) bas2_instance = combine_basis.instantiate_basis( - 5, bas2, class_specific_params, window_size=10, label="label" + 5, bas2, basis_class_specific_params, window_size=10, label="label" ) bas3_instance = combine_basis.instantiate_basis( - 5, bas3, class_specific_params, window_size=10, label="label" + 5, bas3, basis_class_specific_params, window_size=10, label="label" ) bas_obj = bas1_instance + bas2_instance + bas3_instance @@ -4443,7 +4397,7 @@ def test_duplicate_keys(bas1, bas2, bas3, class_specific_params): ], ) def test_split_feature_axis( - bas1, bas2, x, axis, expectation, exp_shapes, class_specific_params + bas1, bas2, x, axis, expectation, exp_shapes, basis_class_specific_params ): # skip nested if any( @@ -4455,10 +4409,10 @@ def test_split_feature_axis( n_basis = [5, 6] combine_basis = CombinedBasis() bas1_instance = combine_basis.instantiate_basis( - n_basis[0], bas1, class_specific_params, window_size=10, label="1" + n_basis[0], bas1, basis_class_specific_params, window_size=10, label="1" ) bas2_instance = combine_basis.instantiate_basis( - n_basis[1], bas2, class_specific_params, window_size=10, label="2" + n_basis[1], bas2, basis_class_specific_params, window_size=10, label="2" ) bas = bas1_instance + bas2_instance diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 5e4ce13d..9e52a4f2 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -21,7 +21,7 @@ ) def test_sklearn_transformer_pipeline(bas, poissonGLM_model_instantiation): X, y, model, _, _ = poissonGLM_model_instantiation - bas = TransformerBasis(bas) + bas = TransformerBasis(bas.set_input_shape(*([1] * bas._n_input_dimensionality))) pipe = pipeline.Pipeline([("eval", bas), ("fit", model)]) pipe.fit(X[:, : bas._basis._n_input_dimensionality] ** 2, y) @@ -39,7 +39,7 @@ def test_sklearn_transformer_pipeline(bas, poissonGLM_model_instantiation): ) def test_sklearn_transformer_pipeline_cv(bas, poissonGLM_model_instantiation): X, y, model, _, _ = poissonGLM_model_instantiation - bas = TransformerBasis(bas) + bas = TransformerBasis(bas.set_input_shape(*([1] * bas._n_input_dimensionality))) pipe = pipeline.Pipeline([("basis", bas), ("fit", model)]) param_grid = dict(basis__n_basis_funcs=(4, 5, 10)) gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score="raise") @@ -60,7 +60,7 @@ def test_sklearn_transformer_pipeline_cv_multiprocess( bas, poissonGLM_model_instantiation ): X, y, model, _, _ = poissonGLM_model_instantiation - bas = TransformerBasis(bas) + bas = TransformerBasis(bas.set_input_shape(*([1] * bas._n_input_dimensionality))) pipe = pipeline.Pipeline([("basis", bas), ("fit", model)]) param_grid = dict(basis__n_basis_funcs=(4, 5, 10)) gridsearch = GridSearchCV( @@ -86,8 +86,15 @@ def test_sklearn_transformer_pipeline_cv_directly_over_basis( ): X, y, model, _, _ = poissonGLM_model_instantiation bas = TransformerBasis(bas_cls(5)) + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) pipe = pipeline.Pipeline([("transformerbasis", bas), ("fit", model)]) - param_grid = dict(transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20))) + param_grid = dict( + transformerbasis___basis=( + bas_cls(5).set_input_shape(*([1] * bas._n_input_dimensionality)), + bas_cls(10).set_input_shape(*([1] * bas._n_input_dimensionality)), + bas_cls(20).set_input_shape(*([1] * bas._n_input_dimensionality)), + ) + ) gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score="raise") gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y) @@ -107,6 +114,7 @@ def test_sklearn_transformer_pipeline_cv_illegal_combination( ): X, y, model, _, _ = poissonGLM_model_instantiation bas = TransformerBasis(bas_cls(5)) + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) pipe = pipeline.Pipeline([("transformerbasis", bas), ("fit", model)]) param_grid = dict( transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20)), @@ -165,6 +173,7 @@ def test_sklearn_transformer_pipeline_pynapple( ep = nap.IntervalSet(start=[0, 20.5], end=[20, X.shape[0]]) X_nap = nap.TsdFrame(t=np.arange(X.shape[0]), d=X, time_support=ep) y_nap = nap.Tsd(t=np.arange(X.shape[0]), d=y, time_support=ep) + bas = bas.set_input_shape(*([1] * bas._n_input_dimensionality)) bas = TransformerBasis(bas) # fit a pipeline & predict from pynapple pipe = pipeline.Pipeline([("eval", bas), ("fit", model)]) From 47174ec64fb5fded1ba2c76d336d3b475becfe7f Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Dec 2024 00:01:21 -0500 Subject: [PATCH 04/41] removed par from docstrings --- src/nemos/basis/_basis.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index adaa2c8a..46d9f5a5 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -111,8 +111,6 @@ class Basis(Base, abc.ABC, BasisTransformerMixin): Parameters ---------- - n_basis_funcs : - The number of basis functions. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. From fd95bcc4caa97ce448544f767c12099cd29e0853 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Dec 2024 10:25:12 -0500 Subject: [PATCH 05/41] improved modularity of sklearn clone --- src/nemos/basis/_basis.py | 25 +-------- src/nemos/basis/_basis_mixin.py | 96 +++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 24 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 46d9f5a5..51538a06 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -625,8 +625,7 @@ def _get_feature_slicing( _get_default_slicing : Handles default slicing logic. _merge_slicing_dicts : Merges multiple slicing dictionaries, handling keys conflicts. """ - # Set default values for n_inputs and start_slice if not provided - n_inputs = n_inputs or self._n_basis_input_ + # Set default values for start_slice if not provided start_slice = start_slice or 0 # Handle the default case for non-additive basis types # See overwritten method for recursion logic @@ -816,28 +815,6 @@ class is accidentally removed. ) return [self] - def __sklearn_clone__(self) -> Basis: - """Clone the basis while preserving attributes related to input shapes. - - This method ensures that input shape attributes (e.g., `_n_basis_input_`, - `_input_shape_`) are preserved during cloning. Reinitializing the class - as in the regular sklearn clone would drop these attributes, rendering - cross-validation unusable. - The method also handles recursive cloning for composite basis structures. - """ - # clone recursively - if hasattr(self, "_basis1") and hasattr(self, "_basis2"): - basis1 = self._basis1.__sklearn_clone__() - basis2 = self._basis2.__sklearn_clone__() - klass = self.__class__(basis1, basis2) - - else: - klass = self.__class__(**self.get_params()) - - for attr_name in ["_n_basis_input_", "_input_shape_"]: - setattr(klass, attr_name, getattr(self, attr_name)) - return klass - class AdditiveBasis(CompositeBasisMixin, Basis): """ diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 9b208e31..e7413afd 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -6,6 +6,7 @@ import copy import inspect import warnings +from functools import wraps from typing import TYPE_CHECKING, Optional, Tuple, Union import numpy as np @@ -19,6 +20,52 @@ from ._basis import Basis +def set_input_shape_state(method): + """ + Decorator to preserve input shape-related attributes during method execution. + + This decorator ensures that the attributes `_n_basis_input_` and `_input_shape_` + are copied from the original object (`self`) to the returned object (`klass`) + after the wrapped method executes. It is intended to be used with methods that + clone or create a new instance of the class, ensuring these critical attributes + are retained for functionality such as cross-validation. + + Parameters + ---------- + method : + The method to be wrapped. This method is expected to return an object + (`klass`) that requires the `_n_basis_input_` and `_input_shape_` attributes. + + Returns + ------- + : + The wrapped method that copies `_n_basis_input_` and `_input_shape_` from + the original object (`self`) to the new object (`klass`). + + Examples + -------- + Applying the decorator to a method: + + >>> from functools import wraps + >>> @set_input_shape_state + ... def __sklearn_clone__(self): + ... klass = self.__class__(**self.get_params()) + ... return klass + + The `_n_basis_input_` and `_input_shape_` attributes of `self` will be + copied to `klass` after the method executes. + """ + + @wraps(method) + def wrapper(self, *args, **kwargs): + klass: Basis = method(self, *args, **kwargs) + for attr_name in ["_n_basis_input_", "_input_shape_"]: + setattr(klass, attr_name, getattr(self, attr_name)) + return klass + + return wrapper + + class EvalBasisMixin: """Mixin class for evaluational basis.""" @@ -118,6 +165,21 @@ def bounds(self, values: Union[None, Tuple[float, float]]): f"Invalid bound {values}. Lower bound is greater or equal than the upper bound." ) + @set_input_shape_state + def __sklearn_clone__(self) -> Basis: + """Clone the basis while preserving attributes related to input shapes. + + This method ensures that input shape attributes (e.g., `_n_basis_input_`, + `_input_shape_`) are preserved during cloning. Reinitializing the class + as in the regular sklearn clone would drop these attributes, rendering + cross-validation unusable. + """ + klass = self.__class__(**self.get_params()) + + # for attr_name in ["_n_basis_input_", "_input_shape_"]: + # setattr(klass, attr_name, getattr(self, attr_name)) + return klass + class ConvBasisMixin: """Mixin class for convolutional basis.""" @@ -309,6 +371,21 @@ def _check_has_kernel(self) -> None: "You must call `_set_kernel` before `_compute_features` for Conv basis." ) + @set_input_shape_state + def __sklearn_clone__(self) -> Basis: + """Clone the basis while preserving attributes related to input shapes. + + This method ensures that input shape attributes (e.g., `_n_basis_input_`, + `_input_shape_`) are preserved during cloning. Reinitializing the class + as in the regular sklearn clone would drop these attributes, rendering + cross-validation unusable. + """ + klass = self.__class__(**self.get_params()) + + for attr_name in ["_n_basis_input_", "_input_shape_"]: + setattr(klass, attr_name, getattr(self, attr_name)) + return klass + class BasisTransformerMixin: """Mixin class for constructing a transformer.""" @@ -462,3 +539,22 @@ def _list_components(self): A list with all 1d basis components. """ return self._basis1._list_components() + self._basis2._list_components() + + @set_input_shape_state + def __sklearn_clone__(self) -> Basis: + """Clone the basis while preserving attributes related to input shapes. + + This method ensures that input shape attributes (e.g., `_n_basis_input_`, + `_input_shape_`) are preserved during cloning. Reinitializing the class + as in the regular sklearn clone would drop these attributes, rendering + cross-validation unusable. + The method also handles recursive cloning for composite basis structures. + """ + # clone recursively + basis1 = self._basis1.__sklearn_clone__() + basis2 = self._basis2.__sklearn_clone__() + klass = self.__class__(basis1, basis2) + + for attr_name in ["_n_basis_input_", "_input_shape_"]: + setattr(klass, attr_name, getattr(self, attr_name)) + return klass From 096b9bdd58c472dabca084e4594386c585f68ae4 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Dec 2024 11:39:35 -0500 Subject: [PATCH 06/41] added test for list component --- src/nemos/basis/_basis.py | 211 ++++------------------- src/nemos/basis/_basis_mixin.py | 166 ++++++++++++++---- src/nemos/basis/_decaying_exponential.py | 7 +- src/nemos/basis/_raised_cosine_basis.py | 10 +- src/nemos/basis/_spline_basis.py | 21 ++- src/nemos/basis/basis.py | 93 ++++------ tests/test_basis.py | 41 +++++ 7 files changed, 273 insertions(+), 276 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 51538a06..d1a72dbc 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -146,8 +146,6 @@ def __init__( else: self._label = str(label) - self._check_n_basis_min() - # specified only after inputs/input shapes are provided self._n_basis_input_ = getattr(self, "_n_basis_input_", None) self._input_shape_ = getattr(self, "_input_shape_", None) @@ -278,12 +276,13 @@ def _set_input_independent_states(self): """ Compute all the basis states that do not depend on the input. - An example of such state is the kernel_ for Conv baisis, which can be computed + An example of such state is the kernel_ for Conv bases, which can be computed without any input (it only depends on the basis type, the window size and the number of basis elements). """ pass + @abc.abstractmethod def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Set the expected input shape for the basis object. @@ -293,54 +292,8 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): from an array. The method also calculates the total number of input features and output features based on the number of basis functions. - Parameters - ---------- - xi : - The input shape specification. - - An integer: Represents the dimensionality of the input. A value of ``1`` is treated as scalar input. - - A tuple: Represents the exact input shape excluding the first axis (sample axis). - All elements must be integers. - - An array: The shape is extracted, excluding the first axis (assumed to be the sample axis). - - Raises - ------ - ValueError - If a tuple is provided and it contains non-integer elements. - - Returns - ------- - self : - Returns the instance itself to allow method chaining. - - Notes - ----- - All state attributes that depends on the input must be set in this method in order for - the API of basis to work correctly. In particular, this method is called by ``_basis_fit``, - which is equivalent to ``fit`` for a transformer. If any input dependent state - is not set in this method, then ``compute_features`` (equivalent to ``fit_transform``) will break. - - Separating states related to the input (settable with this method) and states that are unrelated - from the input (settable with ``set_kernel`` for Conv bases) is a deliberate design choice - that improves modularity. - """ - if isinstance(xi, tuple): - if not all(isinstance(i, int) for i in xi): - raise ValueError( - f"The tuple provided contains non integer values. Tuple: {xi}." - ) - shape = xi - elif isinstance(xi, int): - shape = () if xi == 1 else (xi,) - else: - shape = xi.shape[1:] - - n_inputs = (int(np.prod(shape)),) - - self._input_shape_ = shape - - self._n_basis_input_ = n_inputs - return self + pass @abc.abstractmethod def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: @@ -510,20 +463,6 @@ def _check_samples_consistency(*xi: NDArray) -> None: "Sample size mismatch. Input elements have inconsistent sample sizes." ) - @abc.abstractmethod - def _check_n_basis_min(self) -> None: - """Check that the user required enough basis elements. - - Most of the basis work with at least 1 element, but some - such as the RaisedCosineBasisLog requires a minimum of 2 basis to be well defined. - - Raises - ------ - ValueError - If an insufficient number of basis element is requested for the basis type - """ - pass - def __add__(self, other: Basis) -> AdditiveBasis: """ Add two Basis objects together. @@ -792,29 +731,6 @@ def _check_input_shape_consistency(self, x: NDArray): "different shape, please create a new basis instance." ) - def _list_components(self): - """List all basis components. - - This is re-implemented for composite basis in the mixin class. - - Returns - ------- - A list with all 1d basis components. - - Raises - ------ - RuntimeError - If the basis has multiple components. This would only happen if there is an - implementation issue, for example, if a composite basis is implemented but the - mixin class is not initialized, or if the _list_components method of the composite mixin - class is accidentally removed. - """ - if hasattr(self, "basis1"): - raise RuntimeError( - "Composite basis must implement the _list_components method." - ) - return [self] - class AdditiveBasis(CompositeBasisMixin, Basis): """ @@ -872,34 +788,9 @@ def n_output_features(self): return None return out1 + out2 + @add_docstring("set_input_shape", CompositeBasisMixin) def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: """ - Set the expected input shape for the basis object. - - This method sets the input shape for each component basis in the ``AdditiveBasis``. - One ``xi`` must be provided for each basis component, specified as an integer, - a tuple of integers, or an array. The method calculates and stores the total number of output features - based on the number of basis functions in each component and the provided input shapes. - - Parameters - ---------- - *xi : - The input shape specifications. For every k, ``xi[k]`` can be: - - An integer: Represents the dimensionality of the input. A value of ``1`` is treated as scalar input. - - A tuple: Represents the exact input shape excluding the first axis (sample axis). - All elements must be integers. - - An array: The shape is extracted, excluding the first axis (assumed to be the sample axis). - - Raises - ------ - ValueError - If a tuple is provided and it contains non-integer elements. - - Returns - ------- - self : - Returns the instance itself to allow method chaining. - Examples -------- >>> # Generate sample data @@ -921,15 +812,7 @@ def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: 181 """ - self._n_basis_input_ = ( - *self._basis1.set_input_shape( - *xi[: self._basis1._n_input_dimensionality] - )._n_basis_input_, - *self._basis2.set_input_shape( - *xi[self._basis1._n_input_dimensionality :] - )._n_basis_input_, - ) - return self + return super().set_input_shape(*xi) @support_pynapple(conv_type="numpy") @check_transform_input @@ -1383,64 +1266,6 @@ def _compute_features( ) return X - def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: - """ - Set the expected input shape for the basis object. - - This method sets the input shape for each component basis in the ``MultiplicativeBasis``. - One ``xi`` must be provided for each basis component, specified as an integer, - a tuple of integers, or an array. The method calculates and stores the total number of output features - based on the number of basis functions in each component and the provided input shapes. - - Parameters - ---------- - *xi : - The input shape specifications. For every k,``xi[k]`` can be: - - An integer: Represents the dimensionality of the input. A value of ``1`` is treated as scalar input. - - A tuple: Represents the exact input shape excluding the first axis (sample axis). - All elements must be integers. - - An array: The shape is extracted, excluding the first axis (assumed to be the sample axis). - - Raises - ------ - ValueError - If a tuple is provided and it contains non-integer elements. - - Returns - ------- - self : - Returns the instance itself to allow method chaining. - Examples - -------- - >>> # Generate sample data - >>> import numpy as np - >>> import nemos as nmo - - >>> # define an additive basis - >>> basis_1 = nmo.basis.BSplineEval(5) - >>> basis_2 = nmo.basis.RaisedCosineLinearEval(6) - >>> basis_3 = nmo.basis.MSplineEval(7) - >>> multiplicative_basis = basis_1 * basis_2 * basis_3 - - Specify the input shape using all 3 allowed ways: integer, tuple, array - >>> _ = multiplicative_basis.set_input_shape(1, (2, 3), np.ones((10, 4, 5))) - - Expected output features are: - (5 * 6 * 7 bases) * (1 * 6 * 20 inputs) = 25200 - >>> multiplicative_basis.n_output_features - 25200 - - """ - self._n_basis_input_ = ( - *self._basis1.set_input_shape( - *xi[: self._basis1._n_input_dimensionality] - )._n_basis_input_, - *self._basis2.set_input_shape( - *xi[self._basis1._n_input_dimensionality :] - )._n_basis_input_, - ) - return self - def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: """Evaluate the basis set on a grid of equi-spaced sample points. @@ -1535,3 +1360,29 @@ def split_by_feature( """ return super().split_by_feature(x, axis=axis) + + @add_docstring("set_input_shape", CompositeBasisMixin) + def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: + """ + Examples + -------- + >>> # Generate sample data + >>> import numpy as np + >>> import nemos as nmo + + >>> # define an additive basis + >>> basis_1 = nmo.basis.BSplineEval(5) + >>> basis_2 = nmo.basis.RaisedCosineLinearEval(6) + >>> basis_3 = nmo.basis.MSplineEval(7) + >>> multiplicative_basis = basis_1 * basis_2 * basis_3 + + Specify the input shape using all 3 allowed ways: integer, tuple, array + >>> _ = multiplicative_basis.set_input_shape(1, (2, 3), np.ones((10, 4, 5))) + + Expected output features are: + (5 * 6 * 7 bases) * (1 * 6 * 20 inputs) = 25200 + >>> multiplicative_basis.n_output_features + 25200 + + """ + return super().set_input_shape(*xi) diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index e7413afd..731c270f 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -66,14 +66,102 @@ def wrapper(self, *args, **kwargs): return wrapper +class AtomicBasisMixin: + + def __init__(self, n_basis_funcs: int): + self._n_basis_funcs = n_basis_funcs + self._check_n_basis_min() + + @set_input_shape_state + def __sklearn_clone__(self) -> Basis: + """Clone the basis while preserving attributes related to input shapes. + + This method ensures that input shape attributes (e.g., `_n_basis_input_`, + `_input_shape_`) are preserved during cloning. Reinitializing the class + as in the regular sklearn clone would drop these attributes, rendering + cross-validation unusable. + """ + klass = self.__class__(**self.get_params()) + + for attr_name in ["_n_basis_input_", "_input_shape_"]: + setattr(klass, attr_name, getattr(self, attr_name)) + return klass + + + def _list_components(self): + """List all basis components. + + For atomic bases, the list is just [self]. + + Returns + ------- + A list with the basis components. + + """ + return [self] + + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): + """ + Set the expected input shape for the basis object. + + This method configures the shape of the input data that the basis object expects. + ``xi`` can be specified as an integer, a tuple of integers, or derived + from an array. The method also calculates the total number of input + features and output features based on the number of basis functions. + + Parameters + ---------- + xi : + The input shape specification. + - An integer: Represents the dimensionality of the input. A value of ``1`` is treated as scalar input. + - A tuple: Represents the exact input shape excluding the first axis (sample axis). + All elements must be integers. + - An array: The shape is extracted, excluding the first axis (assumed to be the sample axis). + + Raises + ------ + ValueError + If a tuple is provided and it contains non-integer elements. + + Returns + ------- + self : + Returns the instance itself to allow method chaining. + + Notes + ----- + All state attributes that depends on the input must be set in this method in order for + the API of basis to work correctly. In particular, this method is called by ``setup_basis``, + which is equivalent to ``fit`` for a transformer. If any input dependent state + is not set in this method, then ``compute_features`` (equivalent to ``fit_transform``) will break. + + """ + if isinstance(xi, tuple): + if not all(isinstance(i, int) for i in xi): + raise ValueError( + f"The tuple provided contains non integer values. Tuple: {xi}." + ) + shape = xi + elif isinstance(xi, int): + shape = () if xi == 1 else (xi,) + else: + shape = xi.shape[1:] + + n_inputs = (int(np.prod(shape)),) + + self._input_shape_ = shape + + self._n_basis_input_ = n_inputs + return self + + class EvalBasisMixin: """Mixin class for evaluational basis.""" def __init__( - self, n_basis_funcs: int, bounds: Optional[Tuple[float, float]] = None + self, bounds: Optional[Tuple[float, float]] = None ): self.bounds = bounds - self._n_basis_funcs = n_basis_funcs def _compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor): """Evaluate basis at sample points. @@ -165,32 +253,17 @@ def bounds(self, values: Union[None, Tuple[float, float]]): f"Invalid bound {values}. Lower bound is greater or equal than the upper bound." ) - @set_input_shape_state - def __sklearn_clone__(self) -> Basis: - """Clone the basis while preserving attributes related to input shapes. - - This method ensures that input shape attributes (e.g., `_n_basis_input_`, - `_input_shape_`) are preserved during cloning. Reinitializing the class - as in the regular sklearn clone would drop these attributes, rendering - cross-validation unusable. - """ - klass = self.__class__(**self.get_params()) - - # for attr_name in ["_n_basis_input_", "_input_shape_"]: - # setattr(klass, attr_name, getattr(self, attr_name)) - return klass - class ConvBasisMixin: """Mixin class for convolutional basis.""" def __init__( - self, n_basis_funcs: int, window_size: int, conv_kwargs: Optional[dict] = None + self, window_size: int, conv_kwargs: Optional[dict] = None ): self.kernel_ = None self.window_size = window_size self.conv_kwargs = {} if conv_kwargs is None else conv_kwargs - self._n_basis_funcs = n_basis_funcs + def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor): """Convolve basis functions with input time series. @@ -371,20 +444,6 @@ def _check_has_kernel(self) -> None: "You must call `_set_kernel` before `_compute_features` for Conv basis." ) - @set_input_shape_state - def __sklearn_clone__(self) -> Basis: - """Clone the basis while preserving attributes related to input shapes. - - This method ensures that input shape attributes (e.g., `_n_basis_input_`, - `_input_shape_`) are preserved during cloning. Reinitializing the class - as in the regular sklearn clone would drop these attributes, rendering - cross-validation unusable. - """ - klass = self.__class__(**self.get_params()) - - for attr_name in ["_n_basis_input_", "_input_shape_"]: - setattr(klass, attr_name, getattr(self, attr_name)) - return klass class BasisTransformerMixin: @@ -460,9 +519,6 @@ def n_basis_funcs(self): """Read only property for composite bases.""" pass - def _check_n_basis_min(self) -> None: - pass - def setup_basis(self, *xi: NDArray) -> Basis: """ Set all basis states. @@ -558,3 +614,41 @@ def __sklearn_clone__(self) -> Basis: for attr_name in ["_n_basis_input_", "_input_shape_"]: setattr(klass, attr_name, getattr(self, attr_name)) return klass + + def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: + """ + Set the expected input shape for the basis object. + + This method sets the input shape for each component basis in the basis. + One ``xi`` must be provided for each basis component, specified as an integer, + a tuple of integers, or an array. The method calculates and stores the total number of output features + based on the number of basis functions in each component and the provided input shapes. + + Parameters + ---------- + *xi : + The input shape specifications. For every k,``xi[k]`` can be: + - An integer: Represents the dimensionality of the input. A value of ``1`` is treated as scalar input. + - A tuple: Represents the exact input shape excluding the first axis (sample axis). + All elements must be integers. + - An array: The shape is extracted, excluding the first axis (assumed to be the sample axis). + + Raises + ------ + ValueError + If a tuple is provided and it contains non-integer elements. + + Returns + ------- + self : + Returns the instance itself to allow method chaining. + """ + self._n_basis_input_ = ( + *self._basis1.set_input_shape( + *xi[: self._basis1._n_input_dimensionality] + )._n_basis_input_, + *self._basis2.set_input_shape( + *xi[self._basis1._n_input_dimensionality :] + )._n_basis_input_, + ) + return self diff --git a/src/nemos/basis/_decaying_exponential.py b/src/nemos/basis/_decaying_exponential.py index 7df05947..5f80df58 100644 --- a/src/nemos/basis/_decaying_exponential.py +++ b/src/nemos/basis/_decaying_exponential.py @@ -14,13 +14,16 @@ from ..type_casting import support_pynapple from ..typing import FeatureMatrix from ._basis import Basis, check_transform_input, min_max_rescale_samples +from ._basis_mixin import AtomicBasisMixin -class OrthExponentialBasis(Basis, abc.ABC): +class OrthExponentialBasis(Basis, AtomicBasisMixin, abc.ABC): """Set of 1D basis decaying exponential functions numerically orthogonalized. Parameters ---------- + n_basis_funcs + Number of basis functions. decay_rates : Decay rates of the exponentials, shape ``(n_basis_funcs,)``. mode : @@ -33,10 +36,12 @@ class OrthExponentialBasis(Basis, abc.ABC): def __init__( self, + n_basis_funcs: int, decay_rates: NDArray[np.floating], mode="eval", label: Optional[str] = "OrthExponentialBasis", ): + AtomicBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs) super().__init__( mode=mode, label=label, diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index 0521a683..7b0a3765 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -11,9 +11,9 @@ from ..type_casting import support_pynapple from ..typing import FeatureMatrix from ._basis import Basis, check_transform_input, min_max_rescale_samples +from ._basis_mixin import AtomicBasisMixin - -class RaisedCosineBasisLinear(Basis, abc.ABC): +class RaisedCosineBasisLinear(Basis, AtomicBasisMixin, abc.ABC): """Represent linearly-spaced raised cosine basis functions. This implementation is based on the cosine bumps used by Pillow et al. [1]_ @@ -21,6 +21,8 @@ class RaisedCosineBasisLinear(Basis, abc.ABC): Parameters ---------- + n_basis_funcs : + The number of basis functions. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. @@ -40,10 +42,12 @@ class RaisedCosineBasisLinear(Basis, abc.ABC): def __init__( self, + n_basis_funcs: int, mode="eval", width: float = 2.0, label: Optional[str] = "RaisedCosineBasisLinear", ) -> None: + AtomicBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs) super().__init__( mode=mode, label=label, @@ -230,6 +234,7 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear, abc.ABC): def __init__( self, + n_basis_funcs: int, mode="eval", width: float = 2.0, time_scaling: float = None, @@ -237,6 +242,7 @@ def __init__( label: Optional[str] = "RaisedCosineBasisLog", ) -> None: super().__init__( + n_basis_funcs, mode=mode, width=width, label=label, diff --git a/src/nemos/basis/_spline_basis.py b/src/nemos/basis/_spline_basis.py index 5fc4c38e..7c54fddb 100644 --- a/src/nemos/basis/_spline_basis.py +++ b/src/nemos/basis/_spline_basis.py @@ -13,14 +13,16 @@ from ..type_casting import support_pynapple from ..typing import FeatureMatrix from ._basis import Basis, check_transform_input, min_max_rescale_samples +from ._basis_mixin import AtomicBasisMixin - -class SplineBasis(Basis, abc.ABC): +class SplineBasis(Basis, AtomicBasisMixin, abc.ABC): """ SplineBasis class inherits from the Basis class and represents spline basis functions. Parameters ---------- + n_basis_funcs : + Number of basis functions. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. @@ -38,11 +40,13 @@ class SplineBasis(Basis, abc.ABC): def __init__( self, + n_basis_funcs: int, order: int = 2, label: Optional[str] = None, mode: Literal["conv", "eval"] = "eval", ) -> None: self.order = order + AtomicBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs) super().__init__( label=label, mode=mode, @@ -154,6 +158,9 @@ class MSplineBasis(SplineBasis, abc.ABC): Parameters ---------- + n_basis_funcs : + The number of basis functions to generate. More basis functions allow for + more flexible data modeling but can lead to overfitting. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. @@ -191,11 +198,13 @@ class MSplineBasis(SplineBasis, abc.ABC): def __init__( self, + n_basis_funcs: int, mode: Literal["eval", "conv"] = "eval", order: int = 2, label: Optional[str] = "MSplineEval", ) -> None: super().__init__( + n_basis_funcs, mode=mode, order=order, label=label, @@ -292,6 +301,8 @@ class BSplineBasis(SplineBasis, abc.ABC): Parameters ---------- + n_basis_funcs : + Number of basis functions. mode : The mode of operation. ``'eval'`` for evaluation at sample points, 'conv' for convolutional operation. @@ -317,11 +328,13 @@ class BSplineBasis(SplineBasis, abc.ABC): def __init__( self, + n_basis_funcs: int, mode="eval", order: int = 4, label: Optional[str] = "BSplineBasis", ): super().__init__( + n_basis_funcs, mode=mode, order=order, label=label, @@ -406,6 +419,8 @@ class CyclicBSplineBasis(SplineBasis, abc.ABC): Parameters ---------- + n_basis_funcs : + Number of basis functions. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. @@ -427,11 +442,13 @@ class CyclicBSplineBasis(SplineBasis, abc.ABC): def __init__( self, + n_basis_funcs: int, mode="eval", order: int = 4, label: Optional[str] = "CyclicBSplineBasis", ): super().__init__( + n_basis_funcs, mode=mode, order=order, label=label, diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 702f7b56..758e1f6d 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -9,7 +9,7 @@ from ..typing import FeatureMatrix from ._basis import add_docstring -from ._basis_mixin import ConvBasisMixin, EvalBasisMixin +from ._basis_mixin import AtomicBasisMixin, ConvBasisMixin, EvalBasisMixin from ._decaying_exponential import OrthExponentialBasis from ._raised_cosine_basis import RaisedCosineBasisLinear, RaisedCosineBasisLog from ._spline_basis import BSplineBasis, CyclicBSplineBasis, MSplineBasis @@ -83,13 +83,15 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "BSplineEval", ): - EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) + BSplineBasis.__init__( self, + n_basis_funcs, mode="eval", order=order, label=label, ) + EvalBasisMixin.__init__(self, bounds=bounds) @add_docstring("split_by_feature", BSplineBasis) def split_by_feature( @@ -182,7 +184,7 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): 100 """ - return super().set_input_shape(xi) + return AtomicBasisMixin.set_input_shape(self, xi) class BSplineConv(ConvBasisMixin, BSplineBasis): @@ -236,14 +238,10 @@ def __init__( label: Optional[str] = "BSplineConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__( - self, - n_basis_funcs=n_basis_funcs, - window_size=window_size, - conv_kwargs=conv_kwargs, - ) + ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) BSplineBasis.__init__( self, + n_basis_funcs, mode="conv", order=order, label=label, @@ -340,7 +338,7 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): 100 """ - return super().set_input_shape(xi) + return AtomicBasisMixin.set_input_shape(self, xi) class CyclicBSplineEval(EvalBasisMixin, CyclicBSplineBasis): @@ -381,9 +379,10 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "CyclicBSplineEval", ): - EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) + EvalBasisMixin.__init__(self, bounds=bounds) CyclicBSplineBasis.__init__( self, + n_basis_funcs, mode="eval", order=order, label=label, @@ -480,7 +479,7 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): 100 """ - return super().set_input_shape(xi) + return AtomicBasisMixin.set_input_shape(self, xi) class CyclicBSplineConv(ConvBasisMixin, CyclicBSplineBasis): @@ -526,14 +525,10 @@ def __init__( label: Optional[str] = "CyclicBSplineConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__( - self, - n_basis_funcs=n_basis_funcs, - window_size=window_size, - conv_kwargs=conv_kwargs, - ) + ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) CyclicBSplineBasis.__init__( self, + n_basis_funcs, mode="conv", order=order, label=label, @@ -630,7 +625,7 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): 100 """ - return super().set_input_shape(xi) + return AtomicBasisMixin.set_input_shape(self, xi) class MSplineEval(EvalBasisMixin, MSplineBasis): @@ -695,9 +690,10 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "MSplineEval", ): - EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) + EvalBasisMixin.__init__(self, bounds=bounds) MSplineBasis.__init__( self, + n_basis_funcs, mode="eval", order=order, label=label, @@ -794,7 +790,7 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): 100 """ - return super().set_input_shape(xi) + return AtomicBasisMixin.set_input_shape(self, xi) class MSplineConv(ConvBasisMixin, MSplineBasis): @@ -864,14 +860,10 @@ def __init__( label: Optional[str] = "MSplineConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__( - self, - n_basis_funcs=n_basis_funcs, - window_size=window_size, - conv_kwargs=conv_kwargs, - ) + ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) MSplineBasis.__init__( self, + n_basis_funcs, mode="conv", order=order, label=label, @@ -968,7 +960,7 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): 100 """ - return super().set_input_shape(xi) + return AtomicBasisMixin.set_input_shape(self, xi) class RaisedCosineLinearEval(EvalBasisMixin, RaisedCosineBasisLinear): @@ -1017,9 +1009,10 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "RaisedCosineLinearEval", ): - EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) + EvalBasisMixin.__init__(self, bounds=bounds) RaisedCosineBasisLinear.__init__( self, + n_basis_funcs, width=width, mode="eval", label=label, @@ -1109,7 +1102,7 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): 100 """ - return super().set_input_shape(xi) + return AtomicBasisMixin.set_input_shape(self, xi) class RaisedCosineLinearConv(ConvBasisMixin, RaisedCosineBasisLinear): @@ -1163,14 +1156,10 @@ def __init__( label: Optional[str] = "RaisedCosineLinearConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__( - self, - n_basis_funcs=n_basis_funcs, - window_size=window_size, - conv_kwargs=conv_kwargs, - ) + ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) RaisedCosineBasisLinear.__init__( self, + n_basis_funcs, mode="conv", width=width, label=label, @@ -1260,7 +1249,7 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): 100 """ - return super().set_input_shape(xi) + return AtomicBasisMixin.set_input_shape(self, xi) class RaisedCosineLogEval(EvalBasisMixin, RaisedCosineBasisLog): @@ -1323,9 +1312,10 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "RaisedCosineLogEval", ): - EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) + EvalBasisMixin.__init__(self, bounds=bounds) RaisedCosineBasisLog.__init__( self, + n_basis_funcs, width=width, time_scaling=time_scaling, enforce_decay_to_zero=enforce_decay_to_zero, @@ -1417,7 +1407,7 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): 100 """ - return super().set_input_shape(xi) + return AtomicBasisMixin.set_input_shape(self, xi) class RaisedCosineLogConv(ConvBasisMixin, RaisedCosineBasisLog): @@ -1481,14 +1471,10 @@ def __init__( label: Optional[str] = "RaisedCosineLogConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__( - self, - n_basis_funcs=n_basis_funcs, - window_size=window_size, - conv_kwargs=conv_kwargs, - ) + ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) RaisedCosineBasisLog.__init__( self, + n_basis_funcs, mode="conv", width=width, time_scaling=time_scaling, @@ -1580,7 +1566,7 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): 100 """ - return super().set_input_shape(xi) + return AtomicBasisMixin.set_input_shape(self, xi) class OrthExponentialEval(EvalBasisMixin, OrthExponentialBasis): @@ -1623,9 +1609,10 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "OrthExponentialEval", ): - EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) + EvalBasisMixin.__init__(self, bounds=bounds) OrthExponentialBasis.__init__( self, + n_basis_funcs, decay_rates=decay_rates, mode="eval", label=label, @@ -1719,7 +1706,7 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): 100 """ - return super().set_input_shape(xi) + return AtomicBasisMixin.set_input_shape(self, xi) class OrthExponentialConv(ConvBasisMixin, OrthExponentialBasis): @@ -1765,14 +1752,10 @@ def __init__( label: Optional[str] = "OrthExponentialConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__( - self, - n_basis_funcs=n_basis_funcs, - window_size=window_size, - conv_kwargs=conv_kwargs, - ) + ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) OrthExponentialBasis.__init__( self, + n_basis_funcs, mode="conv", decay_rates=decay_rates, label=label, @@ -1870,7 +1853,7 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): 100 """ - return super().set_input_shape(xi) + return AtomicBasisMixin.set_input_shape(self, xi) def _check_window_size(self, window_size: int): """OrthExponentialBasis specific window size check.""" diff --git a/tests/test_basis.py b/tests/test_basis.py index 86c58fa2..770bce35 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -1328,6 +1328,20 @@ def test_set_input_value_types(self, inp_shape, expectation, cls): with expectation: bas.set_input_shape(inp_shape) + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 6})] + ) + def test_list_component(self, mode, kwargs, cls): + basis_obj = cls[mode]( + n_basis_funcs=5, + **kwargs, + **extra_decay_rates(cls[mode], 5), + ) + + out = basis_obj._list_components() + assert len(out) == 1 + assert id(out[0]) == id(basis_obj) + class TestRaisedCosineLogBasis(BasisFuncsTesting): cls = {"eval": basis.RaisedCosineLogEval, "conv": basis.RaisedCosineLogConv} @@ -1958,6 +1972,33 @@ def test_samples_range_matches_compute_features_requirements( class TestAdditiveBasis(CombinedBasis): cls = {"eval": AdditiveBasis, "conv": AdditiveBasis} + @pytest.mark.parametrize("basis_a", list_all_basis_classes()) + @pytest.mark.parametrize("basis_b", list_all_basis_classes()) + def test_list_component(self, basis_a, basis_b, basis_class_specific_params): + basis_a_obj = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_b_obj = self.instantiate_basis( + 6, basis_b, basis_class_specific_params, window_size=10 + ) + add = basis_a_obj + basis_b_obj + out = add._list_components() + + assert len(out) == add._n_input_dimensionality + + def get_ids(bas): + + if hasattr(bas, "basis1"): + ids = get_ids(bas.basis1) + ids += get_ids(bas.basis2) + else: + ids = [id(bas)] + return ids + + id_list = get_ids(add) + + assert tuple(id(o) for o in out) == tuple(id_list) + @pytest.mark.parametrize("samples", [[[0], []], [[], [0]], [[0, 0], [0, 0]]]) @pytest.mark.parametrize("base_cls", [basis.BSplineEval, basis.BSplineConv]) def test_non_empty_samples(self, base_cls, samples, basis_class_specific_params): From dc6bd9e2c304837076a37e6ba58fc8efeddfa938 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Dec 2024 11:40:48 -0500 Subject: [PATCH 07/41] linted --- src/nemos/basis/_basis_mixin.py | 11 ++--------- src/nemos/basis/_raised_cosine_basis.py | 1 + src/nemos/basis/_spline_basis.py | 1 + 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 731c270f..2bf33633 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -87,7 +87,6 @@ def __sklearn_clone__(self) -> Basis: setattr(klass, attr_name, getattr(self, attr_name)) return klass - def _list_components(self): """List all basis components. @@ -158,9 +157,7 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): class EvalBasisMixin: """Mixin class for evaluational basis.""" - def __init__( - self, bounds: Optional[Tuple[float, float]] = None - ): + def __init__(self, bounds: Optional[Tuple[float, float]] = None): self.bounds = bounds def _compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor): @@ -257,14 +254,11 @@ def bounds(self, values: Union[None, Tuple[float, float]]): class ConvBasisMixin: """Mixin class for convolutional basis.""" - def __init__( - self, window_size: int, conv_kwargs: Optional[dict] = None - ): + def __init__(self, window_size: int, conv_kwargs: Optional[dict] = None): self.kernel_ = None self.window_size = window_size self.conv_kwargs = {} if conv_kwargs is None else conv_kwargs - def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor): """Convolve basis functions with input time series. @@ -445,7 +439,6 @@ def _check_has_kernel(self) -> None: ) - class BasisTransformerMixin: """Mixin class for constructing a transformer.""" diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index 7b0a3765..dbf039eb 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -13,6 +13,7 @@ from ._basis import Basis, check_transform_input, min_max_rescale_samples from ._basis_mixin import AtomicBasisMixin + class RaisedCosineBasisLinear(Basis, AtomicBasisMixin, abc.ABC): """Represent linearly-spaced raised cosine basis functions. diff --git a/src/nemos/basis/_spline_basis.py b/src/nemos/basis/_spline_basis.py index 7c54fddb..c8f42d90 100644 --- a/src/nemos/basis/_spline_basis.py +++ b/src/nemos/basis/_spline_basis.py @@ -15,6 +15,7 @@ from ._basis import Basis, check_transform_input, min_max_rescale_samples from ._basis_mixin import AtomicBasisMixin + class SplineBasis(Basis, AtomicBasisMixin, abc.ABC): """ SplineBasis class inherits from the Basis class and represents spline basis functions. From 7471d3ca94005b521b0deb4c8c313d562317c6ec Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Dec 2024 11:46:19 -0500 Subject: [PATCH 08/41] linted --- pyproject.toml | 2 +- src/nemos/basis/_basis_mixin.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d20fd307..7dd81daa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,7 +126,7 @@ testpaths = ["tests"] # Specify the directory where test files are l filterwarnings = [ # note the use of single quote below to denote "raw" strings in TOML 'ignore:plotting functions contained within:UserWarning', - 'ignore:Tolerance of \d\.\d+e-\d\d reached:RuntimeWarning', + 'ignore:Tolerance of -?\d\.\d+e-\d\d reached:RuntimeWarning', ] [tool.coverage.run] diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 2bf33633..d1178205 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -67,6 +67,7 @@ def wrapper(self, *args, **kwargs): class AtomicBasisMixin: + """Mixin class for atomic bases (i.e. non-composite).""" def __init__(self, n_basis_funcs: int): self._n_basis_funcs = n_basis_funcs From f85be92d0680dbab46583a2ca41a44e5ff0b4e90 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Dec 2024 11:50:39 -0500 Subject: [PATCH 09/41] fixed warn --- tests/test_identifiability_constraints.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_identifiability_constraints.py b/tests/test_identifiability_constraints.py index ca4f4be2..3320faf7 100644 --- a/tests/test_identifiability_constraints.py +++ b/tests/test_identifiability_constraints.py @@ -190,6 +190,7 @@ def test_feature_matrix_dtype(dtype, expected_dtype): ) def test_apply_constraint_with_invalid(invalid_entries): """Test if the matrix retains its dtype after applying constraints.""" + jax.config.update('jax_enable_x64', True) x = np.random.randn(10, 5) # add invalid x[:2, 2] = invalid_entries From 59d8657c31faffd42fa2839abd4c509a63350b30 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Dec 2024 11:54:28 -0500 Subject: [PATCH 10/41] comments on ignore --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 7dd81daa..f57d3238 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,8 +125,12 @@ profile = "black" testpaths = ["tests"] # Specify the directory where test files are located filterwarnings = [ # note the use of single quote below to denote "raw" strings in TOML + # this is raised whenever one imports the plotting utils 'ignore:plotting functions contained within:UserWarning', + # numerical inversion test reaches tolerance... 'ignore:Tolerance of -?\d\.\d+e-\d\d reached:RuntimeWarning', + # mpl must be non-interctive for testing otherwise doctests will freeze + 'ignore:FigureCanvasAgg is non-interactive, and thus cannot be shown:UserWarning', ] [tool.coverage.run] From 4c5e3b70424767c301427bd0816b7c577ee5b44f Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Dec 2024 12:21:57 -0500 Subject: [PATCH 11/41] fix warns --- src/nemos/identifiability_constraints.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/nemos/identifiability_constraints.py b/src/nemos/identifiability_constraints.py index b949b489..d0f7709e 100644 --- a/src/nemos/identifiability_constraints.py +++ b/src/nemos/identifiability_constraints.py @@ -218,6 +218,8 @@ def apply_identifiability_constraints( >>> from nemos.identifiability_constraints import apply_identifiability_constraints >>> from nemos.basis import BSplineEval >>> from nemos.glm import GLM + >>> import jax + >>> jax.config.update('jax_enable_x64', True) >>> # define a feature matrix >>> bas = BSplineEval(5) + BSplineEval(6) >>> feature_matrix = bas.compute_features(np.random.randn(100), np.random.randn(100)) @@ -280,9 +282,11 @@ def apply_identifiability_constraints_by_basis_component( Examples -------- >>> import numpy as np + >>> import jax >>> from nemos.identifiability_constraints import apply_identifiability_constraints_by_basis_component >>> from nemos.basis import BSplineEval >>> from nemos.glm import GLM + >>> jax.config.update('jax_enable_x64', True) >>> # define a feature matrix >>> bas = BSplineEval(5) + BSplineEval(6) >>> feature_matrix = bas.compute_features(np.random.randn(100), np.random.randn(100)) From ac6323e1eeb13b57ce0710e5c18fad95f70ac563 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Dec 2024 12:22:43 -0500 Subject: [PATCH 12/41] linted tests --- tests/test_identifiability_constraints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_identifiability_constraints.py b/tests/test_identifiability_constraints.py index 3320faf7..f40ad214 100644 --- a/tests/test_identifiability_constraints.py +++ b/tests/test_identifiability_constraints.py @@ -190,7 +190,7 @@ def test_feature_matrix_dtype(dtype, expected_dtype): ) def test_apply_constraint_with_invalid(invalid_entries): """Test if the matrix retains its dtype after applying constraints.""" - jax.config.update('jax_enable_x64', True) + jax.config.update("jax_enable_x64", True) x = np.random.randn(10, 5) # add invalid x[:2, 2] = invalid_entries From 50548f0ff43ca34ce12555ce6f5a7c9a64b5a120 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Dec 2024 12:32:28 -0500 Subject: [PATCH 13/41] improved unpacking --- src/nemos/basis/_transformer_basis.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 91865028..c2b5f51b 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -100,14 +100,11 @@ def _unpack_inputs(self, X: FeatureMatrix) -> List: """ n_samples = X.shape[0] - out = [] - cc = 0 - for i, bas in enumerate(self._list_components()): - n_input = self._n_basis_input_[i] - out.append( - np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape_)) - ) - cc += n_input + out = [ + np.reshape(X[:, cc: cc + n_input], (n_samples, *bas._input_shape_)) + for i, (bas, n_input) in enumerate(zip(self._list_components(), self._n_basis_input_)) + for cc in [sum(self._n_basis_input_[:i])] + ] return out def fit(self, X: FeatureMatrix, y=None): From 9281a00ff11734ddab3b9cd71b500c5ef715056f Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Dec 2024 12:36:00 -0500 Subject: [PATCH 14/41] lint --- src/nemos/basis/_transformer_basis.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index c2b5f51b..a049b9fd 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -101,8 +101,10 @@ def _unpack_inputs(self, X: FeatureMatrix) -> List: """ n_samples = X.shape[0] out = [ - np.reshape(X[:, cc: cc + n_input], (n_samples, *bas._input_shape_)) - for i, (bas, n_input) in enumerate(zip(self._list_components(), self._n_basis_input_)) + np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape_)) + for i, (bas, n_input) in enumerate( + zip(self._list_components(), self._n_basis_input_) + ) for cc in [sum(self._n_basis_input_[:i])] ] return out From 5667dd8cb769cab4322d52662c95e7054fd6a45e Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Dec 2024 17:20:56 -0500 Subject: [PATCH 15/41] start note to transformer --- .../plot_05_sklearn_pipeline_cv_demo.md | 8 +- .../how_to_guide/plot_05_transformer_basis.md | 151 ++++++++++++++++++ src/nemos/basis/_basis.py | 21 --- src/nemos/basis/_basis_mixin.py | 21 +++ 4 files changed, 176 insertions(+), 25 deletions(-) create mode 100644 docs/how_to_guide/plot_05_transformer_basis.md diff --git a/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md b/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md index 9f5a9652..461f1b31 100644 --- a/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md +++ b/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md @@ -71,7 +71,8 @@ To set up a scikit-learn [`Pipeline`](https://scikit-learn.org/1.5/modules/gener Each transformation step takes a 2D array `X` of shape `(num_samples, num_original_features)` as input and outputs another 2D array of shape `(num_samples, num_transformed_features)`. The final step takes a pair `(X, y)`, where `X` is as before, and `y` is a 1D array of shape `(n_samples,)` containing the observations to be modeled. You can define a pipeline as follows: -```python + +```{code} ipython3 from sklearn.pipeline import Pipeline # Assume transformer_i/predictor is a transformer/model object @@ -92,7 +93,7 @@ Here we used a placeholder `"label_i"` for demonstration; you should choose a mo ::: Calling `pipe.fit(X, y)` will perform the following computations: -```python +```{code} ipython3 # Chain of transformations X1 = transformer_1.fit_transform(X) X2 = transformer_2.fit_transform(X1) @@ -111,6 +112,7 @@ Pipelines not only streamline and simplify your code but also offer several othe In the following sections, we will showcase this approach with a concrete example: selecting the appropriate basis type and number of bases for a GLM regression in NeMoS. ## Combining basis transformations and GLM in a pipeline + Let's start by creating some toy data. @@ -148,7 +150,6 @@ ax.set_xlabel("input") ax.set_ylabel("spike count") sns.despine(ax=ax) ``` - ### Converting NeMoS `Basis` to a transformer In order to use NeMoS [`Basis`](nemos.basis._basis.Basis) in a pipeline, we need to convert it into a scikit-learn transformer. This can be achieved through the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) wrapper class. @@ -182,7 +183,6 @@ bas.n_basis_funcs = 100 print(bas.n_basis_funcs, trans_bas.n_basis_funcs) ``` - ### Creating and fitting a pipeline We might want to combine first transforming the input data with our basis functions, then fitting a GLM on the transformed data. diff --git a/docs/how_to_guide/plot_05_transformer_basis.md b/docs/how_to_guide/plot_05_transformer_basis.md new file mode 100644 index 00000000..01cbd546 --- /dev/null +++ b/docs/how_to_guide/plot_05_transformer_basis.md @@ -0,0 +1,151 @@ +# Converting NeMoS Bases To scikit-learn Transformers + +## scikit-learn Transformers and NeMoS Basis + +`scikit-learn` is a great machine learning package that provides advanced tooling for creating data analysis pipelines, from input transformations to model fitting and cross-validation. + +All of `scikit-learn` machinery relies on very strong assumptions on how one should structure the inputs to each processing step. +In particular, all `scikit-learn` objects requires inputs in the form of arrays of at most two-dimensions, where the first dimension always represents time (or samples) dimension, and the other features. +This may feel a bit rigid at first, but what this buys you is that any transformation can be chained to any other, greatly simplifying the process of building stable complex pipelines. + +In `scikit-learn`, the data transformation steps are performed by object called `transformers`. + + +On the other hand, NeMoS basis are powerful feature constructors that allow a high degree of flexibility in terms of the required input structure. +Depending on the basis type, it can accept one or more input arrays or `pynapple` time series data, each of which can have any shape as long as the time (or sample) axis is the first of each array; +NeMoS design favours object composability, one can combine any two or more bases to compute complex features, and a user-friendly interface can accept a separate array/time series for each input type (e.g., an array with the spike counts, an array for the animal's position, etc.). + +Both approaches to data transformations are valuable and have their own advantages. +Wouldn't it be great if one could combine them? Well, this is what NeMoS `TransformerBasis` are for! + +:::{admonition} Select Basis Hyperparameters with skit-learn +:class: note + +If you want to learn more about basis hyperparameters with `sklearn` pipelining and cross-validation, check out [this guide](sklearn-how-to). +::: + +## From Basis to TransformerBasis + + +With NeMoS, you can easily create a basis accepting two inputs. Let's assume that we want to process the neural activity as a 2-dimensional spike count array of shape `(n_samples, n_neurons)` and a second array with the speed of an animal of shape `(n_samples,)`. + +```{code-block} ipython3 +import numpy as np +import nemos as nmo + +# create the arrays +n_samples, n_neurons = 100, 5 +counts = np.random.poisson(size=(100, 5)) +speed = np.random.normal(size=(100)) + +# create a composite basis +counts_basis = nmo.basis.RaisedCosineLogConv(5, window_size=10) +speed_basis = nmo.basis.BSplineBasis(5) +composite_basis = counts_basis + speed_basis + +# compute the features +X = composite_basis.compute_features(counts, speed) + +``` + +### Converting NeMoS `Basis` to a transformer + +Now, imagine that we want to use this basis as a step in a `scikit-learn` pipeline. +In this standard (for NeMoS) form, it would not be possible the `composite_basis` object requires two inputs. We need to convert it first into a compliant scikit-learn transformer. This can be achieved through the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) wrapper class. + +Instantiating a [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) can be done either using by the constructor directly or with [`Basis.to_transformer()`](nemos.basis._basis.Basis.to_transformer): + + +```{code-cell} ipython3 +bas = nmo.basis.RaisedCosineLinearConv(5, window_size=5) + +# initalize using the constructor +trans_bas = nmo.basis.TransformerBasis(bas) + +# equivalent initialization via "to_transformer" +trans_bas = bas.to_transformer() + +``` + +[`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) provides convenient access to the underlying [`Basis`](nemos.basis._basis.Basis) object's attributes: + + +```{code-cell} ipython3 +print(bas.n_basis_funcs, trans_bas.n_basis_funcs) +``` + +We can also set attributes of the underlying [`Basis`](nemos.basis._basis.Basis). Note that -- because [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) is created with a copy of the [`Basis`](nemos.basis._basis.Basis) object passed to it -- this does not change the original [`Basis`](nemos.basis._basis.Basis), and neither does changing the original [`Basis`](nemos.basis._basis.Basis) change [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) we created: + + +```{code-cell} ipython3 +trans_bas.n_basis_funcs = 10 +bas.n_basis_funcs = 100 + +print(bas.n_basis_funcs, trans_bas.n_basis_funcs) +``` + +As any `sckit-learn` tansformer, the `TransformerBasis` implements `fit`, a preparation step, `transform`, the actual feature computation, and `fit_transform` which chains `fit` and `transform`. These methods comply with the `scikit-learn` input structure convention, and therefore all accepts a single 2D array. + +## Setting up the TransformerBasis + +At this point we have an object equipped with the correct methods, so now all we have to do is concatenate the inputs into a unique array and call `fit_transform`, right? + +```{code-cell} ipython3 +# reinstantiate the basis transformer for illustration porpuses +composite_basis = counts_basis + speed_basis +transformer_bas = (composite_basis).to_transformer() + +# concatenate the inputs +inp = np.concatenate([counts, speed[:, np.newaxis]]) +print(inp.shape) + +try: + transformer_bas.fit_transform(inp) +except RuntimeError as e: + print(repr(e)) + +``` + +Unfortunately not yet. The problem is that the basis has never interacted with the two separate inputs, and therefore doesn't know which columns of `inp` should be processed by `count_basis` and which by `speed_basis`. + +There are several ways in which you can provide this information to the basis. The first one is by calling the method `set_input_shape`. + +This can be called before or after the transformer basis is defined. The method extracts and store the array shapes excluding the sample axis (which won't be affected in the concatenation). + +`set_input_shape` accepts directly the inputs, + +```{code-cell} ipython3 + +composite_basis.set_input_shape(counts, speed) +out = transformer_bas.fit_transform(inp) +``` + +If the input is 1D or 2D, the number of columns, +```{code-cell} ipython3 + +trans_bas = composite_basis.set_input_shape(5, 1).transformer() +out = transformer_bas.fit_transform(inp) +``` + +A tuple containing the shapes of all axis other than the first, +```{code-cell} ipython3 + +composite_basis.set_input_shape((5,), (1,)) +out = transformer_bas.fit_transform(inp) +``` + +Or a mix of any of the above. + +```{code-cell} ipython3 + +composite_basis.set_input_shape(counts, 1) +out = transformer_bas.fit_transform(inp) +``` + +You can also invert the order and call `to_transform` first and set the input shapes after. +```{code-cell} ipython3 + +trans_bas = composite_basis.transformer() +trans_bas.set_input_shape(5, 1) +out = transformer_bas.fit_transform(inp) +``` \ No newline at end of file diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index d1a72dbc..b6a66a56 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -710,27 +710,6 @@ def is_leaf(val): ) return reshaped_out - def _check_input_shape_consistency(self, x: NDArray): - """Check input consistency across calls.""" - # remove sample axis and squeeze - shape = x.shape[1:] - - initialized = self._input_shape_ is not None - is_shape_match = self._input_shape_ == shape - if initialized and not is_shape_match: - expected_shape_str = "(n_samples, " + f"{self._input_shape_}"[1:] - expected_shape_str = expected_shape_str.replace(",)", ")") - raise ValueError( - f"Input shape mismatch detected.\n\n" - f"The basis `{self.__class__.__name__}` with label '{self.label}' expects inputs with " - f"a consistent shape (excluding the sample axis). Specifically, the shape should be:\n" - f" Expected: {expected_shape_str}\n" - f" But got: {x.shape}.\n\n" - "Note: The number of samples (`n_samples`) can vary between calls of `compute_features`, " - "but all other dimensions must remain the same. If you need to process inputs with a " - "different shape, please create a new basis instance." - ) - class AdditiveBasis(CompositeBasisMixin, Basis): """ diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index d1178205..f5635118 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -155,6 +155,27 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): return self + def _check_input_shape_consistency(self, x: NDArray): + """Check input consistency across calls.""" + # remove sample axis and squeeze + shape = x.shape[1:] + + initialized = self._input_shape_ is not None + is_shape_match = self._input_shape_ == shape + if initialized and not is_shape_match: + expected_shape_str = "(n_samples, " + f"{self._input_shape_}"[1:] + expected_shape_str = expected_shape_str.replace(",)", ")") + raise ValueError( + f"Input shape mismatch detected.\n\n" + f"The basis `{self.__class__.__name__}` with label '{self.label}' expects inputs with " + f"a consistent shape (excluding the sample axis). Specifically, the shape should be:\n" + f" Expected: {expected_shape_str}\n" + f" But got: {x.shape}.\n\n" + "Note: The number of samples (`n_samples`) can vary between calls of `compute_features`, " + "but all other dimensions must remain the same. If you need to process inputs with a " + "different shape, please create a new basis instance." + ) + class EvalBasisMixin: """Mixin class for evaluational basis.""" From 9829777460d3e66fd853d661d6c62fdf0502db00 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Dec 2024 18:38:08 -0500 Subject: [PATCH 16/41] finished note and modified the logic --- .../how_to_guide/plot_05_transformer_basis.md | 65 ++++++++++++------- src/nemos/basis/_transformer_basis.py | 45 ++++++++++++- 2 files changed, 85 insertions(+), 25 deletions(-) diff --git a/docs/how_to_guide/plot_05_transformer_basis.md b/docs/how_to_guide/plot_05_transformer_basis.md index 01cbd546..7df68de5 100644 --- a/docs/how_to_guide/plot_05_transformer_basis.md +++ b/docs/how_to_guide/plot_05_transformer_basis.md @@ -18,11 +18,6 @@ NeMoS design favours object composability, one can combine any two or more bases Both approaches to data transformations are valuable and have their own advantages. Wouldn't it be great if one could combine them? Well, this is what NeMoS `TransformerBasis` are for! -:::{admonition} Select Basis Hyperparameters with skit-learn -:class: note - -If you want to learn more about basis hyperparameters with `sklearn` pipelining and cross-validation, check out [this guide](sklearn-how-to). -::: ## From Basis to TransformerBasis @@ -40,7 +35,7 @@ speed = np.random.normal(size=(100)) # create a composite basis counts_basis = nmo.basis.RaisedCosineLogConv(5, window_size=10) -speed_basis = nmo.basis.BSplineBasis(5) +speed_basis = nmo.basis.BSplineEval(5) composite_basis = counts_basis + speed_basis # compute the features @@ -51,9 +46,9 @@ X = composite_basis.compute_features(counts, speed) ### Converting NeMoS `Basis` to a transformer Now, imagine that we want to use this basis as a step in a `scikit-learn` pipeline. -In this standard (for NeMoS) form, it would not be possible the `composite_basis` object requires two inputs. We need to convert it first into a compliant scikit-learn transformer. This can be achieved through the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) wrapper class. +In this standard (for NeMoS) form, it would not be possible the `composite_basis` object requires two inputs. We need to convert it first into a compliant scikit-learn transformer. This can be achieved through the [`TransformerBasis`](nemos.basis._trans_basis.TransformerBasis) wrapper class. -Instantiating a [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) can be done either using by the constructor directly or with [`Basis.to_transformer()`](nemos.basis._basis.Basis.to_transformer): +Instantiating a [`TransformerBasis`](nemos.basis._trans_basis.TransformerBasis) can be done either using by the constructor directly or with [`Basis.to_transformer()`](nemos.basis._basis.Basis.to_transformer): ```{code-cell} ipython3 @@ -67,14 +62,14 @@ trans_bas = bas.to_transformer() ``` -[`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) provides convenient access to the underlying [`Basis`](nemos.basis._basis.Basis) object's attributes: +[`TransformerBasis`](nemos.basis._trans_basis.TransformerBasis) provides convenient access to the underlying [`Basis`](nemos.basis._basis.Basis) object's attributes: ```{code-cell} ipython3 print(bas.n_basis_funcs, trans_bas.n_basis_funcs) ``` -We can also set attributes of the underlying [`Basis`](nemos.basis._basis.Basis). Note that -- because [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) is created with a copy of the [`Basis`](nemos.basis._basis.Basis) object passed to it -- this does not change the original [`Basis`](nemos.basis._basis.Basis), and neither does changing the original [`Basis`](nemos.basis._basis.Basis) change [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) we created: +We can also set attributes of the underlying [`Basis`](nemos.basis._basis.Basis). Note that -- because [`TransformerBasis`](nemos.basis._trans_basis.TransformerBasis) is created with a copy of the [`Basis`](nemos.basis._basis.Basis) object passed to it -- this does not change the original [`Basis`](nemos.basis._basis.Basis), and neither does changing the original [`Basis`](nemos.basis._basis.Basis) change [`TransformerBasis`](nemos.basis._trans_basis.TransformerBasis) we created: ```{code-cell} ipython3 @@ -91,19 +86,19 @@ As any `sckit-learn` tansformer, the `TransformerBasis` implements `fit`, a prep At this point we have an object equipped with the correct methods, so now all we have to do is concatenate the inputs into a unique array and call `fit_transform`, right? ```{code-cell} ipython3 + # reinstantiate the basis transformer for illustration porpuses composite_basis = counts_basis + speed_basis -transformer_bas = (composite_basis).to_transformer() +trans_bas = (composite_basis).to_transformer() # concatenate the inputs -inp = np.concatenate([counts, speed[:, np.newaxis]]) +inp = np.concatenate([counts, speed[:, np.newaxis]], axis=1) print(inp.shape) try: - transformer_bas.fit_transform(inp) + trans_bas.fit_transform(inp) except RuntimeError as e: print(repr(e)) - ``` Unfortunately not yet. The problem is that the basis has never interacted with the two separate inputs, and therefore doesn't know which columns of `inp` should be processed by `count_basis` and which by `speed_basis`. @@ -117,35 +112,59 @@ This can be called before or after the transformer basis is defined. The method ```{code-cell} ipython3 composite_basis.set_input_shape(counts, speed) -out = transformer_bas.fit_transform(inp) +out = composite_basis.to_transformer().fit_transform(inp) ``` If the input is 1D or 2D, the number of columns, ```{code-cell} ipython3 trans_bas = composite_basis.set_input_shape(5, 1).transformer() -out = transformer_bas.fit_transform(inp) +out = composite_basis.to_transformer().fit_transform(inp) ``` A tuple containing the shapes of all axis other than the first, ```{code-cell} ipython3 composite_basis.set_input_shape((5,), (1,)) -out = transformer_bas.fit_transform(inp) +out = composite_basis.to_transformer().fit_transform(inp) ``` -Or a mix of any of the above. - +Or a mix of the above. ```{code-cell} ipython3 composite_basis.set_input_shape(counts, 1) -out = transformer_bas.fit_transform(inp) +out = composite_basis.to_transformer().fit_transform(inp) ``` You can also invert the order and call `to_transform` first and set the input shapes after. ```{code-cell} ipython3 -trans_bas = composite_basis.transformer() +trans_bas = composite_basis.to_transformer() trans_bas.set_input_shape(5, 1) -out = transformer_bas.fit_transform(inp) -``` \ No newline at end of file +out = trans_bas.fit_transform(inp) +``` + +:::{note} + +If you define a NeMoS basis and call `compute_features` on your inputs, internally, the basis will store the +input shapes, and the `TransformerBasis` will be ready to process without any direct call to `set_input_shape`. +::: + +If for some reason you will need to provide an input of different shape to the transformer, you must setup the +`TransformerBasis` again. + +```{code-cell} ipython3 + +# define inputs with different shapes and concatenate +x, y = np.random.poisson(size=(10, 3)), np.random.randn(10, 2, 3) +inp2 = np.concatenate([x, y.reshape(10, 6)], axis=1) + +trans_bas = composite_basis.to_transformer() +trans_bas.set_input_shape(3, (2, 3)) +out2 = trans_bas.fit_transform(inp2) +``` + + +### Learn more + +If you want to learn more about basis hyperparameters with `sklearn` pipelining and cross-validation, check out [this guide](sklearn-how-to). diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index a049b9fd..47f1edcb 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -79,7 +79,6 @@ def basis(self): @basis.setter def basis(self, basis): - self._check_initialized(basis) self._basis = basis def _unpack_inputs(self, X: FeatureMatrix) -> List: @@ -140,6 +139,8 @@ def fit(self, X: FeatureMatrix, y=None): >>> transformer = TransformerBasis(basis) >>> transformer_fitted = transformer.fit(X) """ + self._check_initialized(self._basis) + self._check_input(X, y) self._basis.setup_basis(*self._unpack_inputs(X)) return self @@ -180,6 +181,7 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: >>> # Transform basis >>> feature_transformed = transformer.transform(X) """ + self._check_initialized(self._basis) # transpose does not work with pynapple # can't use func(*X.T) to unwrap return self._basis._compute_features(*self._unpack_inputs(X)) @@ -219,7 +221,8 @@ def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: >>> # Fit and transform basis >>> feature_transformed = transformer.fit_transform(X) """ - return self._basis.compute_features(*self._unpack_inputs(X)) + self.fit(X, y=y) + return self.transform(X) def __getstate__(self): """ @@ -416,3 +419,41 @@ def __pow__(self, exponent: int) -> TransformerBasis: """ # errors are handled by Basis.__pow__ return TransformerBasis(self._basis**exponent) + + def _check_input(self, X: FeatureMatrix, y=None): + """Check that the input structure. + + TransformerBasis expects a 2-d array as an input. The number of columns should match the number of inputs + the basis expects. This number can be set before the TransformerBasis is initialized, by calling + ``Basis.set_input_shape``. + + Parameters + ---------- + X: + The input FeatureMatrix. + + Raises + ------ + ValueError: + If the input is not a 2-d array or if the number of columns does not match the expected number of inputs. + """ + ndim = getattr(X, "ndim", None) + if ndim is None: + raise ValueError("The input must be a 2-dimensional array.") + + elif ndim != 2: + raise ValueError( + f"X must be 2-dimensional, shape (n_samples, n_features). The provided X has shape {X.shape} instead." + ) + + if X.shape[1] != sum(self.n_basis_input_): + raise ValueError( + f"Input mismatch: expected {sum(self.n_basis_input_)} inputs, but got {X.shape[1]} columns in X.\n" + "To modify the required number of inputs, call `set_input_shape` before using `fit` or `fit_transform`." + ) + + if y is not None and y.shape[0] != X.shape[0]: + raise ValueError( + "X and y must have the same number of samples. " + f"X has {X.shpae[0]} samples, while y has {y.shape[0]} samples." + ) From 23a276c81d1c0c7f3aa57c16f5a500b80a015986 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Dec 2024 18:57:54 -0500 Subject: [PATCH 17/41] note on transformer --- docs/how_to_guide/plot_05_transformer_basis.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/how_to_guide/plot_05_transformer_basis.md b/docs/how_to_guide/plot_05_transformer_basis.md index 7df68de5..9a81dffc 100644 --- a/docs/how_to_guide/plot_05_transformer_basis.md +++ b/docs/how_to_guide/plot_05_transformer_basis.md @@ -167,4 +167,5 @@ out2 = trans_bas.fit_transform(inp2) ### Learn more -If you want to learn more about basis hyperparameters with `sklearn` pipelining and cross-validation, check out [this guide](sklearn-how-to). +If you want to learn more about basis how to select basis hyperparameters with `sklearn` pipelining and cross-validation, check out [this guide](sklearn-how-to). + From 1d8119906c0492dcad714c93929aafa9316f013c Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Dec 2024 18:59:42 -0500 Subject: [PATCH 18/41] renamed tutorials --- ...rn_pipeline_cv_demo.md => plot_06_sklearn_pipeline_cv_demo.md} | 0 .../how_to_guide/{plot_06_glm_pytree.md => plot_07_glm_pytree.md} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename docs/how_to_guide/{plot_05_sklearn_pipeline_cv_demo.md => plot_06_sklearn_pipeline_cv_demo.md} (100%) rename docs/how_to_guide/{plot_06_glm_pytree.md => plot_07_glm_pytree.md} (100%) diff --git a/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md similarity index 100% rename from docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md rename to docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md diff --git a/docs/how_to_guide/plot_06_glm_pytree.md b/docs/how_to_guide/plot_07_glm_pytree.md similarity index 100% rename from docs/how_to_guide/plot_06_glm_pytree.md rename to docs/how_to_guide/plot_07_glm_pytree.md From 0340a3b3dce7e56676c72f91d7611cc00bdb7d6c Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Dec 2024 20:16:30 -0500 Subject: [PATCH 19/41] fixed docs --- docs/how_to_guide/README.md | 23 +- .../how_to_guide/plot_05_transformer_basis.md | 26 +- .../plot_06_sklearn_pipeline_cv_demo.md | 45 +-- docs/how_to_guide/plot_07_glm_pytree.md | 2 +- docs/quickstart.md | 320 +++++++++--------- src/nemos/basis/_transformer_basis.py | 2 +- 6 files changed, 219 insertions(+), 199 deletions(-) diff --git a/docs/how_to_guide/README.md b/docs/how_to_guide/README.md index 1a3ec6c8..2a33008f 100644 --- a/docs/how_to_guide/README.md +++ b/docs/how_to_guide/README.md @@ -14,7 +14,7 @@ pip install nemos[examples] ::: -::::{grid} 1 2 3 3 +::::{grid} 1 2 3 4 :::{grid-item-card} @@ -58,20 +58,34 @@ plot_04_batch_glm.md :::{grid-item-card}
-Pipelining and cross-validation. +NeMoS vs sklearn.
```{toctree} :maxdepth: 2 -plot_05_sklearn_pipeline_cv_demo.md +plot_05_transformer_basis.md ``` ::: :::{grid-item-card}
-PyTrees. +PyTrees. +
+ +```{toctree} +:maxdepth: 2 + +plot_06_sklearn_pipeline_cv_demo.md +``` + +::: + +:::{grid-item-card} + +
+PyTrees.
```{toctree} @@ -79,6 +93,7 @@ plot_05_sklearn_pipeline_cv_demo.md plot_06_glm_pytree.md ``` + ::: :::: diff --git a/docs/how_to_guide/plot_05_transformer_basis.md b/docs/how_to_guide/plot_05_transformer_basis.md index 9a81dffc..fa3e99f1 100644 --- a/docs/how_to_guide/plot_05_transformer_basis.md +++ b/docs/how_to_guide/plot_05_transformer_basis.md @@ -1,5 +1,19 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + # Converting NeMoS Bases To scikit-learn Transformers +(tansformer-vs-nemos-basis)= ## scikit-learn Transformers and NeMoS Basis `scikit-learn` is a great machine learning package that provides advanced tooling for creating data analysis pipelines, from input transformations to model fitting and cross-validation. @@ -24,7 +38,7 @@ Wouldn't it be great if one could combine them? Well, this is what NeMoS `Transf With NeMoS, you can easily create a basis accepting two inputs. Let's assume that we want to process the neural activity as a 2-dimensional spike count array of shape `(n_samples, n_neurons)` and a second array with the speed of an animal of shape `(n_samples,)`. -```{code-block} ipython3 +```{code-cell} ipython3 import numpy as np import nemos as nmo @@ -46,9 +60,9 @@ X = composite_basis.compute_features(counts, speed) ### Converting NeMoS `Basis` to a transformer Now, imagine that we want to use this basis as a step in a `scikit-learn` pipeline. -In this standard (for NeMoS) form, it would not be possible the `composite_basis` object requires two inputs. We need to convert it first into a compliant scikit-learn transformer. This can be achieved through the [`TransformerBasis`](nemos.basis._trans_basis.TransformerBasis) wrapper class. +In this standard (for NeMoS) form, it would not be possible the `composite_basis` object requires two inputs. We need to convert it first into a compliant scikit-learn transformer. This can be achieved through the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) wrapper class. -Instantiating a [`TransformerBasis`](nemos.basis._trans_basis.TransformerBasis) can be done either using by the constructor directly or with [`Basis.to_transformer()`](nemos.basis._basis.Basis.to_transformer): +Instantiating a [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) can be done either using by the constructor directly or with [`Basis.to_transformer()`](nemos.basis._basis.Basis.to_transformer): ```{code-cell} ipython3 @@ -62,14 +76,14 @@ trans_bas = bas.to_transformer() ``` -[`TransformerBasis`](nemos.basis._trans_basis.TransformerBasis) provides convenient access to the underlying [`Basis`](nemos.basis._basis.Basis) object's attributes: +[`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) provides convenient access to the underlying [`Basis`](nemos.basis._basis.Basis) object's attributes: ```{code-cell} ipython3 print(bas.n_basis_funcs, trans_bas.n_basis_funcs) ``` -We can also set attributes of the underlying [`Basis`](nemos.basis._basis.Basis). Note that -- because [`TransformerBasis`](nemos.basis._trans_basis.TransformerBasis) is created with a copy of the [`Basis`](nemos.basis._basis.Basis) object passed to it -- this does not change the original [`Basis`](nemos.basis._basis.Basis), and neither does changing the original [`Basis`](nemos.basis._basis.Basis) change [`TransformerBasis`](nemos.basis._trans_basis.TransformerBasis) we created: +We can also set attributes of the underlying [`Basis`](nemos.basis._basis.Basis). Note that -- because [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) is created with a copy of the [`Basis`](nemos.basis._basis.Basis) object passed to it -- this does not change the original [`Basis`](nemos.basis._basis.Basis), and neither does changing the original [`Basis`](nemos.basis._basis.Basis) change [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) we created: ```{code-cell} ipython3 @@ -118,7 +132,7 @@ out = composite_basis.to_transformer().fit_transform(inp) If the input is 1D or 2D, the number of columns, ```{code-cell} ipython3 -trans_bas = composite_basis.set_input_shape(5, 1).transformer() +composite_basis.set_input_shape(5, 1) out = composite_basis.to_transformer().fit_transform(inp) ``` diff --git a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md index 461f1b31..2afa68e8 100644 --- a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md +++ b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md @@ -150,10 +150,9 @@ ax.set_xlabel("input") ax.set_ylabel("spike count") sns.despine(ax=ax) ``` -### Converting NeMoS `Basis` to a transformer -In order to use NeMoS [`Basis`](nemos.basis._basis.Basis) in a pipeline, we need to convert it into a scikit-learn transformer. This can be achieved through the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) wrapper class. -Instantiating a [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) can be done either using by the constructor directly or with [`Basis.to_transformer()`](nemos.basis._basis.Basis.to_transformer): +### Converting NeMoS `Basis` to a transformer +In order to use NeMoS [`Basis`](nemos.basis._basis.Basis) in a pipeline, we need to convert it into a scikit-learn transformer. ```{code-cell} ipython3 @@ -165,24 +164,16 @@ trans_bas = nmo.basis.TransformerBasis(bas) # equivalent initialization via "to_transformer" trans_bas = bas.to_transformer() +# setup the transformer +trans_bas.set_input_shape(1) ``` -[`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) provides convenient access to the underlying [`Basis`](nemos.basis._basis.Basis) object's attributes: - +:::{admonition} Learn More about `TransformerBasis` +:note: -```{code-cell} ipython3 -print(bas.n_basis_funcs, trans_bas.n_basis_funcs) -``` - -We can also set attributes of the underlying [`Basis`](nemos.basis._basis.Basis). Note that -- because [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) is created with a copy of the [`Basis`](nemos.basis._basis.Basis) object passed to it -- this does not change the original [`Basis`](nemos.basis._basis.Basis), and neither does changing the original [`Basis`](nemos.basis._basis.Basis) change [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) we created: - - -```{code-cell} ipython3 -trans_bas.n_basis_funcs = 10 -bas.n_basis_funcs = 100 +To learn more about `sklearn` transformers and `TransforerBasis`, check out [this note](tansformer-vs-nemos-basis). +::: -print(bas.n_basis_funcs, trans_bas.n_basis_funcs) -``` ### Creating and fitting a pipeline We might want to combine first transforming the input data with our basis functions, then fitting a GLM on the transformed data. @@ -194,7 +185,7 @@ pipeline = Pipeline( [ ( "transformerbasis", - nmo.basis.RaisedCosineLinearEval(6).to_transformer(), + nmo.basis.RaisedCosineLinearEval(6).set_input_shape(1).to_transformer(), ), ( "glm", @@ -311,7 +302,7 @@ gridsearch.fit(X, y) To appreciate how much boiler-plate code we are saving by calling scikit-learn cross-validation, below we can see how this cross-validation will look like in a manual loop. -```python +```{code} ipython from itertools import product from copy import deepcopy @@ -439,7 +430,7 @@ if root or Path("../assets/stylesheets").exists(): path.mkdir(parents=True, exist_ok=True) if path.exists(): - fig.savefig(path / "plot_05_sklearn_pipeline_cv_demo.svg") + fig.savefig(path / "plot_06_sklearn_pipeline_cv_demo.svg") ``` 🚀🚀🚀 **Success!** 🚀🚀🚀 @@ -457,12 +448,12 @@ Here we include `transformerbasis___basis` in the parameter grid to try differen param_grid = dict( glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6), transformerbasis___basis=( - nmo.basis.RaisedCosineLinearEval(5), - nmo.basis.RaisedCosineLinearEval(10), - nmo.basis.RaisedCosineLogEval(5), - nmo.basis.RaisedCosineLogEval(10), - nmo.basis.MSplineEval(5), - nmo.basis.MSplineEval(10), + nmo.basis.RaisedCosineLinearEval(5).set_input_shape(1), + nmo.basis.RaisedCosineLinearEval(10).set_input_shape(1), + nmo.basis.RaisedCosineLogEval(5).set_input_shape(1), + nmo.basis.RaisedCosineLogEval(10).set_input_shape(1), + nmo.basis.MSplineEval(5).set_input_shape(1), + nmo.basis.MSplineEval(10).set_input_shape(1), ), ) ``` @@ -538,7 +529,7 @@ The plot confirms that the firing rate distribution is accurately captured by ou :::{warning} Please note that because it would lead to unexpected behavior, mixing the two ways of defining values for the parameter grid is not allowed. The following would lead to an error: -```python +```{code} ipython param_grid = dict( glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6), transformerbasis__n_basis_funcs=(3, 5, 10, 20, 100), diff --git a/docs/how_to_guide/plot_07_glm_pytree.md b/docs/how_to_guide/plot_07_glm_pytree.md index d6608c45..f980e75c 100644 --- a/docs/how_to_guide/plot_07_glm_pytree.md +++ b/docs/how_to_guide/plot_07_glm_pytree.md @@ -265,7 +265,7 @@ if root or Path("../assets/stylesheets").exists(): path.mkdir(parents=True, exist_ok=True) if path.exists(): - fig.savefig(path / "plot_06_glm_pytree.svg") + fig.savefig(path / "plot_07_glm_pytree.svg") ``` diff --git a/docs/quickstart.md b/docs/quickstart.md index bdf3ffd4..f071839f 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -1,7 +1,18 @@ --- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 + language: python + name: python3 hide: - navigation --- + # Quickstart ## **Overview** @@ -29,58 +40,56 @@ NeMoS provides two implementations of the GLM: one for fitting a single neuron, You can define a single neuron GLM by instantiating an `GLM` object. -```python +```{code-cell} ipython3 ->>> import nemos as nmo +import nemos as nmo ->>> # Instantiate the single model ->>> model = nmo.glm.GLM() +# Instantiate the single model +model = nmo.glm.GLM() ``` The coefficients can be learned by invoking the `fit` method of `GLM`. The method requires a design matrix of shape `(num_samples, num_features)`, and the output neural activity of shape `(num_samples, )`. -```python +```{code-cell} ipython3 ->>> import numpy as np ->>> num_samples, num_features = 100, 3 +import numpy as np +num_samples, num_features = 100, 3 ->>> # Generate a design matrix ->>> X = np.random.normal(size=(num_samples, num_features)) ->>> # generate some counts ->>> spike_counts = np.random.poisson(size=num_samples) +# Generate a design matrix +X = np.random.normal(size=(num_samples, num_features)) +# generate some counts +spike_counts = np.random.poisson(size=num_samples) ->>> # define fit the model ->>> model = model.fit(X, spike_counts) +# define fit the model +model = model.fit(X, spike_counts) ``` Once the model is fit, you can retrieve the model parameters as shown below. -```python ->>> # model coefficients shape is (num_features, ) ->>> print(f"Model coefficients shape: {model.coef_.shape}") -Model coefficients shape: (3,) +```{code-cell} ipython3 +# model coefficients shape is (num_features, ) +print(f"Model coefficients shape: {model.coef_.shape}") ->>> # model intercept, shape (1,) since there is only one neuron. ->>> print(f"Model intercept shape: {model.intercept_.shape}") -Model intercept shape: (1,) +# model intercept, shape (1,) since there is only one neuron. +print(f"Model intercept shape: {model.intercept_.shape}") ``` Additionally, you can predict the firing rate and call the compute the model log-likelihood by calling the `predict` and the `score` method respectively. -```python +```{code-cell} ipython3 + +# predict the rate +predicted_rate = model.predict(X) +# firing rate has shape: (num_samples,) +predicted_rate.shape ->>> # predict the rate ->>> predicted_rate = model.predict(X) ->>> # firing rate has shape: (num_samples,) ->>> predicted_rate.shape -(100,) ->>> # compute the log-likelihood of the model ->>> log_likelihood = model.score(X, spike_counts) +# compute the log-likelihood of the model +log_likelihood = model.score(X, spike_counts) ``` @@ -88,49 +97,47 @@ Additionally, you can predict the firing rate and call the compute the model log You can set up a population GLM by instantiating a `PopulationGLM`. The API for the `PopulationGLM` is the same as for the single-neuron `GLM`; the only difference you'll notice is that some of the methods' inputs and outputs have an additional dimension for the different neurons. -```python +```{code-cell} ->>> import nemos as nmo ->>> population_model = nmo.glm.PopulationGLM() +import nemos as nmo +population_model = nmo.glm.PopulationGLM() ``` As for the single neuron GLM, you can learn the model parameters by invoking the `fit` method: the input of `fit` are the design matrix (with shape `(num_samples, num_features)` ), and the population activity (with shape `(num_samples, num_neurons)`). Once the model is fit, you can use `predict` and `score` to predict the firing rate and the log-likelihood. -```python +```{code-cell} ->>> import numpy as np ->>> num_samples, num_features, num_neurons = 100, 3, 5 +import numpy as np +num_samples, num_features, num_neurons = 100, 3, 5 ->>> # simulate a design matrix ->>> X = np.random.normal(size=(num_samples, num_features)) ->>> # simulate some counts ->>> spike_counts = np.random.poisson(size=(num_samples, num_neurons)) +# simulate a design matrix +X = np.random.normal(size=(num_samples, num_features)) +# simulate some counts +spike_counts = np.random.poisson(size=(num_samples, num_neurons)) ->>> # fit the model ->>> population_model = population_model.fit(X, spike_counts) +# fit the model +population_model = population_model.fit(X, spike_counts) ->>> # predict the rate of each neuron in the population ->>> predicted_rate = population_model.predict(X) ->>> predicted_rate.shape # expected shape: (num_samples, num_neurons) -(100, 5) +# predict the rate of each neuron in the population +predicted_rate = population_model.predict(X) +predicted_rate.shape # expected shape: (num_samples, num_neurons) ->>> # compute the log-likelihood of the model ->>> log_likelihood = population_model.score(X, spike_counts) + +# compute the log-likelihood of the model +log_likelihood = population_model.score(X, spike_counts) ``` The learned coefficient and intercept will have shape `(num_features, num_neurons)` and `(num_neurons, )` respectively. -```python ->>> # model coefficients shape is (num_features, num_neurons) ->>> print(f"Model coefficients shape: {population_model.coef_.shape}") -Model coefficients shape: (3, 5) +```{code-cell} +# model coefficients shape is (num_features, num_neurons) +print(f"Model coefficients shape: {population_model.coef_.shape}") ->>> # model intercept, (num_neurons,) ->>> print(f"Model intercept shape: {population_model.intercept_.shape}") -Model intercept shape: (5,) +# model intercept, (num_neurons,) +print(f"Model intercept shape: {population_model.intercept_.shape}") ``` @@ -160,12 +167,12 @@ The `basis` module includes objects that perform two types of transformations on Non-linear mapping is the default mode of operation of any `basis` object. To instantiate a basis for non-linear mapping, you need to specify the number of basis functions. For some `basis` objects, additional arguments may be required (see the [API Reference](nemos_basis) for detailed information). -```python +```{code-cell} ->>> import nemos as nmo +import nemos as nmo ->>> n_basis_funcs = 10 ->>> basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs) +n_basis_funcs = 10 +basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs) ``` @@ -173,17 +180,16 @@ Once the basis is instantiated, you can apply it to your input data using the `c This method takes an input array of shape `(n_samples, )` and transforms it into a two-dimensional array of shape `(n_samples, n_basis_funcs)`, where each column represents a feature generated by the non-linear mapping. -```python +```{code-cell} ->>> import numpy as np +import numpy as np ->>> # generate an input ->>> x = np.arange(100) +# generate an input +x = np.arange(100) ->>> # evaluate the basis ->>> X = basis.compute_features(x) ->>> X.shape -(100, 10) +# evaluate the basis +X = basis.compute_features(x) +X.shape ``` @@ -199,13 +205,13 @@ If you want to convolve a bank of basis functions with an input you must set the `"conv"` and you must provide an integer `window_size` parameter, which defines the length of the filter bank in number of sample points. -```python +```{code-cell} ipython3 ->>> import nemos as nmo +import nemos as nmo ->>> n_basis_funcs = 10 ->>> # define a filter bank of 10 basis function, 200 samples long. ->>> basis = nmo.basis.BSplineConv(n_basis_funcs, window_size=200) +n_basis_funcs = 10 +# define a filter bank of 10 basis function, 200 samples long. +basis = nmo.basis.BSplineConv(n_basis_funcs, window_size=200) ``` @@ -219,23 +225,21 @@ Once the basis is initialized, you can call `compute_features` on an input of sh The `window_size` must be shorter than the number of samples in the signal(s) being convolved. ::: -```python +```{code-cell} ipython3 ->>> import numpy as np +import numpy as np ->>> x = np.ones(500) +x = np.ones(500) ->>> # convolve a single signal ->>> X = basis.compute_features(x) ->>> X.shape -(500, 10) +# convolve a single signal +X = basis.compute_features(x) +X.shape ->>> x_multi = np.ones((500, 3)) +x_multi = np.ones((500, 3)) ->>> # convolve a multiple signals ->>> X_multi = basis.compute_features(x_multi) ->>> X_multi.shape -(500, 30) +# convolve a multiple signals +X_multi = basis.set_input_shape(3).compute_features(x_multi) +X_multi.shape ``` @@ -249,12 +253,12 @@ By default, NeMoS' GLM uses [Poisson observations](nemos.observation_models.Pois To change the default observation model, set the `observation_model` argument during initialization: -```python +```{code-cell} ipython3 ->>> import nemos as nmo +import nemos as nmo ->>> # set up a Gamma GLM for modeling continuous non-negative data ->>> glm = nmo.glm.GLM(observation_model=nmo.observation_models.GammaObservations()) +# set up a Gamma GLM for modeling continuous non-negative data +glm = nmo.glm.GLM(observation_model=nmo.observation_models.GammaObservations()) ``` @@ -270,12 +274,12 @@ NeMoS supports various regularization schemes, including [Ridge](nemos.regulariz You can specify the regularization scheme and its strength when initializing the GLM model: -```python +```{code-cell} ipython3 ->>> import nemos as nmo +import nemos as nmo ->>> # Instantiate a GLM with Ridge (L2) regularization ->>> glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1) +# Instantiate a GLM with Ridge (L2) regularization +glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1) ``` @@ -301,25 +305,26 @@ also be a `pynapple` time series. A canonical example of this behavior is the `predict` method of `GLM`. -```ipython +```{code-cell} ipython3 + +import numpy as np +import pynapple as nap ->>> import numpy as np ->>> import pynapple as nap +# suppress jax to numpy conversion warning +nap.nap_config.suppress_conversion_warnings = True ->>> # create a TsdFrame with the features and a Tsd with the counts ->>> X = nap.TsdFrame(t=np.arange(100), d=np.random.normal(size=(100, 2))) ->>> y = nap.Tsd(t=np.arange(100), d=np.random.poisson(size=(100, ))) +# create a TsdFrame with the features and a Tsd with the counts +X = nap.TsdFrame(t=np.arange(100), d=np.random.normal(size=(100, 2))) +y = nap.Tsd(t=np.arange(100), d=np.random.poisson(size=(100, ))) ->>> print(type(X)) # shape (num samples, num features) - +print(type(X)) # shape (num samples, num features) ->>> model = model.fit(X, y) # the following works +model = model.fit(X, y) # the following works ->>> firing_rate = model.predict(X) # predict the firing rate of the neuron +firing_rate = model.predict(X) # predict the firing rate of the neuron ->>> # this will still be a pynapple time series ->>> print(type(firing_rate)) # shape (num_samples, ) - +# this will still be a pynapple time series +print(type(firing_rate)) # shape (num_samples, ) ``` @@ -331,29 +336,29 @@ Let's see how you can greatly streamline your analysis pipeline by integrating ` You can download this dataset by clicking [here](https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1). ::: -```ipython +```{code-cell} ipython3 ->>> import nemos as nmo ->>> import pynapple as nap +import nemos as nmo +import pynapple as nap ->>> path = nmo.fetch.fetch_data("A2929-200711.nwb") ->>> data = nap.load_file(path) +path = nmo.fetch.fetch_data("A2929-200711.nwb") +data = nap.load_file(path) ->>> # load spikes and head direction ->>> spikes = data["units"] ->>> head_dir = data["ry"] +# load spikes and head direction +spikes = data["units"] +head_dir = data["ry"] ->>> # restrict and bin ->>> counts = spikes[6].count(0.01, ep=head_dir.time_support) +# restrict and bin +counts = spikes[6].count(0.01, ep=head_dir.time_support) ->>> # down-sample head direction ->>> upsampled_head_dir = head_dir.bin_average(0.01) +# down-sample head direction +upsampled_head_dir = head_dir.bin_average(0.01) ->>> # create your features ->>> X = nmo.basis.CyclicBSplineEval(10).compute_features(upsampled_head_dir) +# create your features +X = nmo.basis.CyclicBSplineEval(10).compute_features(upsampled_head_dir) ->>> # add a neuron axis and fit model ->>> model = nmo.glm.GLM().fit(X, counts) +# add a neuron axis and fit model +model = nmo.glm.GLM().fit(X, counts) ``` @@ -361,35 +366,31 @@ You can download this dataset by clicking [here](https://www.dropbox.com/s/su4oa Finally, let's compare the tuning curves -```ipython +```{code-cell} ipython3 ->>> import numpy as np ->>> import matplotlib.pyplot as plt +import numpy as np +import matplotlib.pyplot as plt ->>> # tuning curves ->>> raw_tuning = nap.compute_1d_tuning_curves(spikes, head_dir, nb_bins=100)[6] +# tuning curves +raw_tuning = nap.compute_1d_tuning_curves(spikes, head_dir, nb_bins=100)[6] ->>> # model based tuning curve ->>> model_tuning = nap.compute_1d_tuning_curves_continuous( -... model.predict(X)[:, np.newaxis] * X.rate, # scale by the sampling rate -... head_dir, -... nb_bins=100 -... )[0] +# model based tuning curve +model_tuning = nap.compute_1d_tuning_curves_continuous( + model.predict(X)[:, np.newaxis] * X.rate, # scale by the sampling rate + head_dir, + nb_bins=100 + )[0] ->>> # plot results ->>> sub = plt.subplot(111, projection="polar") ->>> plt1 = plt.plot(raw_tuning.index, raw_tuning.values, label="raw") ->>> plt2 = plt.plot(model_tuning.index, model_tuning.values, label="glm") ->>> legend = plt.yticks([]) ->>> xlab = plt.xlabel("heading angle") +# plot results +sub = plt.subplot(111, projection="polar") +plt1 = plt.plot(raw_tuning.index, raw_tuning.values, label="raw") +plt2 = plt.plot(model_tuning.index, model_tuning.values, label="glm") +legend = plt.yticks([]) +xlab = plt.xlabel("heading angle") ``` - - - - ## **Compatibility with `scikit-learn`** @@ -400,34 +401,34 @@ For example, if we would like to tune the critical hyper-parameter `regularizer_ [^1]: For a detailed explanation and practical examples, refer to the [cross-validation page](https://scikit-learn.org/stable/modules/cross_validation.html) in the `scikit-learn` documentation. -```ipython +```{code-cell} ipython3 ->>> # set up the model ->>> import nemos as nmo ->>> import numpy as np +# set up the model +import nemos as nmo +import numpy as np ->>> # generate data ->>> X, counts = np.random.normal(size=(100, 3)), np.random.poisson(size=100) +# generate data +X, counts = np.random.normal(size=(100, 3)), np.random.poisson(size=100) ->>> # model definition ->>> model = nmo.glm.GLM(regularizer="Ridge") +# model definition +model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1) ``` Fit a 5-fold cross-validation scheme for comparing two different regularizer strengths: -```ipython +```{code-cell} ipython3 ->>> from sklearn.model_selection import GridSearchCV +from sklearn.model_selection import GridSearchCV ->>> # define the parameter grid ->>> param_grid = dict(regularizer_strength=(0.01, 0.001)) +# define the parameter grid +param_grid = dict(regularizer_strength=(0.01, 0.001)) ->>> # define the 5-fold cross-validation grid search from sklearn ->>> cls = GridSearchCV(model, param_grid=param_grid, cv=5) +# define the 5-fold cross-validation grid search from sklearn +cls = GridSearchCV(model, param_grid=param_grid, cv=5) ->>> # run the 5-fold cross-validation grid search ->>> cls = cls.fit(X, counts) +# run the 5-fold cross-validation grid search +cls = cls.fit(X, counts) ``` @@ -440,11 +441,10 @@ For more information and a practical example on how to construct a parameter gri Finally, we can print the regularizer strength with the best cross-validated performance: -```ipython +```{code-cell} ipython3 ->>> # print best regularizer strength ->>> print(cls.best_params_) -{'regularizer_strength': 0.01} +# print best regularizer strength +print(cls.best_params_) ``` diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 47f1edcb..16c01e20 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -312,7 +312,7 @@ def __sklearn_clone__(self) -> TransformerBasis: For more info: https://scikit-learn.org/stable/developers/develop.html#cloning """ - cloned_obj = TransformerBasis(copy.deepcopy(self._basis)) + cloned_obj = TransformerBasis(self._basis.__sklearn_clone__()) cloned_obj._basis.kernel_ = None return cloned_obj From 9629f7fc82a66567a430fdae10d4c23f7a1424c2 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Dec 2024 20:49:54 -0500 Subject: [PATCH 20/41] edited docs and linted --- docs/how_to_guide/README.md | 2 +- .../how_to_guide/plot_05_transformer_basis.md | 58 +++++++++---------- docs/index.md | 2 +- src/nemos/basis/_basis_mixin.py | 2 +- 4 files changed, 31 insertions(+), 33 deletions(-) diff --git a/docs/how_to_guide/README.md b/docs/how_to_guide/README.md index 2a33008f..432eb800 100644 --- a/docs/how_to_guide/README.md +++ b/docs/how_to_guide/README.md @@ -91,7 +91,7 @@ plot_06_sklearn_pipeline_cv_demo.md ```{toctree} :maxdepth: 2 -plot_06_glm_pytree.md +plot_07_glm_pytree.md ``` ::: diff --git a/docs/how_to_guide/plot_05_transformer_basis.md b/docs/how_to_guide/plot_05_transformer_basis.md index fa3e99f1..9b8fe6b6 100644 --- a/docs/how_to_guide/plot_05_transformer_basis.md +++ b/docs/how_to_guide/plot_05_transformer_basis.md @@ -16,27 +16,25 @@ kernelspec: (tansformer-vs-nemos-basis)= ## scikit-learn Transformers and NeMoS Basis -`scikit-learn` is a great machine learning package that provides advanced tooling for creating data analysis pipelines, from input transformations to model fitting and cross-validation. +`scikit-learn` is a powerful machine learning library that provides advanced tools for creating data analysis pipelines, from input transformations to model fitting and cross-validation. -All of `scikit-learn` machinery relies on very strong assumptions on how one should structure the inputs to each processing step. -In particular, all `scikit-learn` objects requires inputs in the form of arrays of at most two-dimensions, where the first dimension always represents time (or samples) dimension, and the other features. -This may feel a bit rigid at first, but what this buys you is that any transformation can be chained to any other, greatly simplifying the process of building stable complex pipelines. +All of `scikit-learn`'s machinery relies on strict assumptions about input structure. In particular, all `scikit-learn` +objects require inputs as arrays of at most two dimensions, where the first dimension represents the time (or samples) +axis, and the second dimension represents features. +While this may feel rigid, it enables transformations to be seamlessly chained together, greatly simplifying the +process of building stable, complex pipelines. -In `scikit-learn`, the data transformation steps are performed by object called `transformers`. +On the other hand, `NeMoS` takes a different approach to feature construction. `NeMoS`' bases are composable constructors that allow for more flexibility in the required input structure. +Depending on the basis type, it can accept one or more input arrays or `pynapple` time series data, each of which can take any shape as long as the time (or sample) axis is the first of each array; +`NeMoS` design favours object composability: one can combine any two or more bases to compute complex features, with a user-friendly interface that can accept a separate array/time series for each input type (e.g., an array with the spike counts, an array for the animal's position, etc.). - -On the other hand, NeMoS basis are powerful feature constructors that allow a high degree of flexibility in terms of the required input structure. -Depending on the basis type, it can accept one or more input arrays or `pynapple` time series data, each of which can have any shape as long as the time (or sample) axis is the first of each array; -NeMoS design favours object composability, one can combine any two or more bases to compute complex features, and a user-friendly interface can accept a separate array/time series for each input type (e.g., an array with the spike counts, an array for the animal's position, etc.). - -Both approaches to data transformations are valuable and have their own advantages. -Wouldn't it be great if one could combine them? Well, this is what NeMoS `TransformerBasis` are for! +Both approaches to data transformation are valuable and each has its own advantages. Wouldn't it be great if one could combine the two? Well, this is what NeMoS `TransformerBasis` is for! ## From Basis to TransformerBasis -With NeMoS, you can easily create a basis accepting two inputs. Let's assume that we want to process the neural activity as a 2-dimensional spike count array of shape `(n_samples, n_neurons)` and a second array with the speed of an animal of shape `(n_samples,)`. +With NeMoS, you can easily create a basis accepting two inputs. Let's assume that we want to process neural activity stored in a 2-dimensional spike count array of shape `(n_samples, n_neurons)` and a second array containing the speed of an animal, with shape `(n_samples,)`. ```{code-cell} ipython3 import numpy as np @@ -60,9 +58,9 @@ X = composite_basis.compute_features(counts, speed) ### Converting NeMoS `Basis` to a transformer Now, imagine that we want to use this basis as a step in a `scikit-learn` pipeline. -In this standard (for NeMoS) form, it would not be possible the `composite_basis` object requires two inputs. We need to convert it first into a compliant scikit-learn transformer. This can be achieved through the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) wrapper class. +In this standard (for NeMoS) form, it would not be possible as the `composite_basis` object requires two inputs. We need to convert it first into a compliant `scikit-learn` transformer. This can be achieved through the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) wrapper class. -Instantiating a [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) can be done either using by the constructor directly or with [`Basis.to_transformer()`](nemos.basis._basis.Basis.to_transformer): +Instantiating a [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) can be done either by using the constructor directly or with [`Basis.to_transformer()`](nemos.basis._basis.Basis.to_transformer): ```{code-cell} ipython3 @@ -83,7 +81,7 @@ trans_bas = bas.to_transformer() print(bas.n_basis_funcs, trans_bas.n_basis_funcs) ``` -We can also set attributes of the underlying [`Basis`](nemos.basis._basis.Basis). Note that -- because [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) is created with a copy of the [`Basis`](nemos.basis._basis.Basis) object passed to it -- this does not change the original [`Basis`](nemos.basis._basis.Basis), and neither does changing the original [`Basis`](nemos.basis._basis.Basis) change [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) we created: +We can also set attributes of the underlying [`Basis`](nemos.basis._basis.Basis). Note that -- because [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) is created with a copy of the [`Basis`](nemos.basis._basis.Basis) object passed to it -- this does not change the original [`Basis`](nemos.basis._basis.Basis), nor does changing the original [`Basis`](nemos.basis._basis.Basis) modify the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) we created: ```{code-cell} ipython3 @@ -93,11 +91,11 @@ bas.n_basis_funcs = 100 print(bas.n_basis_funcs, trans_bas.n_basis_funcs) ``` -As any `sckit-learn` tansformer, the `TransformerBasis` implements `fit`, a preparation step, `transform`, the actual feature computation, and `fit_transform` which chains `fit` and `transform`. These methods comply with the `scikit-learn` input structure convention, and therefore all accepts a single 2D array. +As with any `sckit-learn` transformer, the `TransformerBasis` implements `fit`, a preparation step, `transform`, the actual feature computation, and `fit_transform` which chains `fit` and `transform`. These methods comply with the `scikit-learn` input structure convention, and therefore they all accept a single 2D array. ## Setting up the TransformerBasis -At this point we have an object equipped with the correct methods, so now all we have to do is concatenate the inputs into a unique array and call `fit_transform`, right? +At this point we have an object equipped with the correct methods, so now, all we have to do is concatenate the inputs into a unique array and call `fit_transform`, right? ```{code-cell} ipython3 @@ -115,13 +113,13 @@ except RuntimeError as e: print(repr(e)) ``` -Unfortunately not yet. The problem is that the basis has never interacted with the two separate inputs, and therefore doesn't know which columns of `inp` should be processed by `count_basis` and which by `speed_basis`. +...Unfortunately, not yet. The problem is that the basis has never interacted with the two separate inputs, and therefore doesn't know which columns of `inp` should be processed by `count_basis` and which by `speed_basis`. -There are several ways in which you can provide this information to the basis. The first one is by calling the method `set_input_shape`. +You can provide this information by calling the `set_input_shape` method of the basis. -This can be called before or after the transformer basis is defined. The method extracts and store the array shapes excluding the sample axis (which won't be affected in the concatenation). +This can be called before or after the transformer basis is defined. The method extracts and stores the array shapes excluding the sample axis (which won't be affected in the concatenation). -`set_input_shape` accepts directly the inputs, +`set_input_shape` directly accepts the inputs: ```{code-cell} ipython3 @@ -129,14 +127,14 @@ composite_basis.set_input_shape(counts, speed) out = composite_basis.to_transformer().fit_transform(inp) ``` -If the input is 1D or 2D, the number of columns, +If the input is 1D or 2D, it also accepts the number of columns: ```{code-cell} ipython3 composite_basis.set_input_shape(5, 1) out = composite_basis.to_transformer().fit_transform(inp) ``` -A tuple containing the shapes of all axis other than the first, +A tuple containing the shapes of all the axes other than the first, ```{code-cell} ipython3 composite_basis.set_input_shape((5,), (1,)) @@ -150,7 +148,7 @@ composite_basis.set_input_shape(counts, 1) out = composite_basis.to_transformer().fit_transform(inp) ``` -You can also invert the order and call `to_transform` first and set the input shapes after. +You can also invert the order of operations and call `to_transform` first and then set the input shapes. ```{code-cell} ipython3 trans_bas = composite_basis.to_transformer() @@ -160,12 +158,12 @@ out = trans_bas.fit_transform(inp) :::{note} -If you define a NeMoS basis and call `compute_features` on your inputs, internally, the basis will store the -input shapes, and the `TransformerBasis` will be ready to process without any direct call to `set_input_shape`. +If you define a basis and call `compute_features` on your inputs, internally, it will store its shapes, +and the `TransformerBasis` will be ready to process without any direct call to `set_input_shape`. ::: -If for some reason you will need to provide an input of different shape to the transformer, you must setup the -`TransformerBasis` again. +If for some reason you need to provide an input of different shape to an already set-up transformer, you must reset the +`TransformerBasis` with `set_input_shape`. ```{code-cell} ipython3 @@ -181,5 +179,5 @@ out2 = trans_bas.fit_transform(inp2) ### Learn more -If you want to learn more about basis how to select basis hyperparameters with `sklearn` pipelining and cross-validation, check out [this guide](sklearn-how-to). +If you want to learn more about how to select basis' hyperparameters with `sklearn` pipelining and cross-validation, check out [this how-to guide](sklearn-how-to). diff --git a/docs/index.md b/docs/index.md index 4da5ef06..0b61208c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -20,7 +20,7 @@ For Developers ``` -## __Neural ModelS__ +# __Neural ModelS__ NeMoS (Neural ModelS) is a statistical modeling framework optimized for systems neuroscience and powered by [JAX](https://jax.readthedocs.io/en/latest/). diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index f5635118..3a025925 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -154,7 +154,6 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): self._n_basis_input_ = n_inputs return self - def _check_input_shape_consistency(self, x: NDArray): """Check input consistency across calls.""" # remove sample axis and squeeze @@ -176,6 +175,7 @@ def _check_input_shape_consistency(self, x: NDArray): "different shape, please create a new basis instance." ) + class EvalBasisMixin: """Mixin class for evaluational basis.""" From 15f547e1f31b84a5881c8935bab49cd1c8711903 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Dec 2024 21:10:29 -0500 Subject: [PATCH 21/41] improved svg, fixed docstring --- docs/assets/pipeline.svg | 20 ++++++++++---------- src/nemos/basis/_basis_mixin.py | 6 +++--- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/docs/assets/pipeline.svg b/docs/assets/pipeline.svg index 8c67c7ab..a38b6480 100644 --- a/docs/assets/pipeline.svg +++ b/docs/assets/pipeline.svg @@ -24,12 +24,12 @@ inkscape:deskcolor="#d1d1d1" inkscape:document-units="mm" inkscape:zoom="1.4142136" - inkscape:cx="289.20667" - inkscape:cy="-54.800776" - inkscape:window-width="2488" - inkscape:window-height="1262" + inkscape:cx="206.82873" + inkscape:cy="8.4852811" + inkscape:window-width="1800" + inkscape:window-height="1035" inkscape:window-x="0" - inkscape:window-y="25" + inkscape:window-y="44" inkscape:window-maximized="0" inkscape:current-layer="layer1" /> + y="49.623287" /> Pipeline + x="26.058722" + y="84.056198">Pipeline Basis: Set all basis states. This method corresponds sklearn transformer ``fit``. As fit, it must receive the input and - it must set all basis states, i.e. kernel_ and all the states relative to the input shape. + it must set all basis states, i.e. ``kernel_`` and all the states relative to the input shape. The difference between this method and the transformer ``fit`` is in the expected input structure, where the transformer ``fit`` method requires the inputs to be concatenated in a 2D array, while here each input is provided as a separate time series for each basis element. @@ -323,7 +323,7 @@ def setup_basis(self, *xi: NDArray) -> Basis: Set all basis states. This method corresponds sklearn transformer ``fit``. As fit, it must receive the input and - it must set all basis states, i.e. kernel_ and all the states relative to the input shape. + it must set all basis states, i.e. ``kernel_`` and all the states relative to the input shape. The difference between this method and the transformer ``fit`` is in the expected input structure, where the transformer ``fit`` method requires the inputs to be concatenated in a 2D array, while here each input is provided as a separate time series for each basis element. @@ -539,7 +539,7 @@ def setup_basis(self, *xi: NDArray) -> Basis: Set all basis states. This method corresponds sklearn transformer ``fit``. As fit, it must receive the input and - it must set all basis states, i.e. kernel_ and all the states relative to the input shape. + it must set all basis states, i.e. ``kernel_`` and all the states relative to the input shape. The difference between this method and the transformer ``fit`` is in the expected input structure, where the transformer ``fit`` method requires the inputs to be concatenated in a 2D array, while here each input is provided as a separate time series for each basis element. From 8c8c273c4ab6e388f231ee10a70c670fd4ee8180 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Dec 2024 21:31:17 -0500 Subject: [PATCH 22/41] fixed rendering docstrings --- src/nemos/basis/basis.py | 132 ++++++++++++++------------------------- 1 file changed, 48 insertions(+), 84 deletions(-) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 758e1f6d..2ea0bd86 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -158,7 +158,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples) - @add_docstring("set_input_shape", BSplineBasis) + @add_docstring("set_input_shape", AtomicBasisMixin) def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples @@ -166,18 +166,15 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): >>> import nemos as nmo >>> import numpy as np >>> basis = nmo.basis.BSplineEval(5) - - Configure with an integer input: + >>> # Configure with an integer input: >>> _ = basis.set_input_shape(3) >>> basis.n_output_features 15 - - Configure with a tuple: + >>> # Configure with a tuple: >>> _ = basis.set_input_shape((4, 5)) >>> basis.n_output_features 100 - - Configure with an array: + >>> # Configure with an array: >>> x = np.ones((10, 4, 5)) >>> _ = basis.set_input_shape(x) >>> basis.n_output_features @@ -312,7 +309,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples) - @add_docstring("set_input_shape", BSplineBasis) + @add_docstring("set_input_shape", AtomicBasisMixin) def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples @@ -320,18 +317,15 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): >>> import nemos as nmo >>> import numpy as np >>> basis = nmo.basis.BSplineConv(5, 10) - - Configure with an integer input: + >>> # Configure with an integer input: >>> _ = basis.set_input_shape(3) >>> basis.n_output_features 15 - - Configure with a tuple: + >>> # Configure with a tuple: >>> _ = basis.set_input_shape((4, 5)) >>> basis.n_output_features 100 - - Configure with an array: + >>> # Configure with an array: >>> x = np.ones((10, 4, 5)) >>> _ = basis.set_input_shape(x) >>> basis.n_output_features @@ -453,7 +447,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples) - @add_docstring("set_input_shape", CyclicBSplineBasis) + @add_docstring("set_input_shape", AtomicBasisMixin) def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples @@ -461,18 +455,15 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): >>> import nemos as nmo >>> import numpy as np >>> basis = nmo.basis.CyclicBSplineEval(5) - - Configure with an integer input: + >>> # Configure with an integer input: >>> _ = basis.set_input_shape(3) >>> basis.n_output_features 15 - - Configure with a tuple: + >>> # Configure with a tuple: >>> _ = basis.set_input_shape((4, 5)) >>> basis.n_output_features 100 - - Configure with an array: + >>> # Configure with an array: >>> x = np.ones((10, 4, 5)) >>> _ = basis.set_input_shape(x) >>> basis.n_output_features @@ -599,7 +590,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples) - @add_docstring("set_input_shape", CyclicBSplineBasis) + @add_docstring("set_input_shape", AtomicBasisMixin) def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples @@ -607,18 +598,15 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): >>> import nemos as nmo >>> import numpy as np >>> basis = nmo.basis.CyclicBSplineConv(5, 10) - - Configure with an integer input: + >>> # Configure with an integer input: >>> _ = basis.set_input_shape(3) >>> basis.n_output_features 15 - - Configure with a tuple: + >>> # Configure with a tuple: >>> _ = basis.set_input_shape((4, 5)) >>> basis.n_output_features 100 - - Configure with an array: + >>> # Configure with an array: >>> x = np.ones((10, 4, 5)) >>> _ = basis.set_input_shape(x) >>> basis.n_output_features @@ -764,7 +752,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples) - @add_docstring("set_input_shape", MSplineBasis) + @add_docstring("set_input_shape", AtomicBasisMixin) def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples @@ -772,18 +760,15 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): >>> import nemos as nmo >>> import numpy as np >>> basis = nmo.basis.MSplineEval(5) - - Configure with an integer input: + >>> # Configure with an integer input: >>> _ = basis.set_input_shape(3) >>> basis.n_output_features 15 - - Configure with a tuple: + >>> # Configure with a tuple: >>> _ = basis.set_input_shape((4, 5)) >>> basis.n_output_features 100 - - Configure with an array: + >>> # Configure with an array: >>> x = np.ones((10, 4, 5)) >>> _ = basis.set_input_shape(x) >>> basis.n_output_features @@ -934,7 +919,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples) - @add_docstring("set_input_shape", MSplineBasis) + @add_docstring("set_input_shape", AtomicBasisMixin) def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples @@ -942,18 +927,15 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): >>> import nemos as nmo >>> import numpy as np >>> basis = nmo.basis.MSplineConv(5, 10) - - Configure with an integer input: + >>> # Configure with an integer input: >>> _ = basis.set_input_shape(3) >>> basis.n_output_features 15 - - Configure with a tuple: + >>> # Configure with a tuple: >>> _ = basis.set_input_shape((4, 5)) >>> basis.n_output_features 100 - - Configure with an array: + >>> # Configure with an array: >>> x = np.ones((10, 4, 5)) >>> _ = basis.set_input_shape(x) >>> basis.n_output_features @@ -1076,7 +1058,7 @@ def split_by_feature( """ return super().split_by_feature(x, axis=axis) - @add_docstring("set_input_shape", RaisedCosineBasisLinear) + @add_docstring("set_input_shape", AtomicBasisMixin) def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples @@ -1084,18 +1066,15 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): >>> import nemos as nmo >>> import numpy as np >>> basis = nmo.basis.RaisedCosineLinearEval(5) - - Configure with an integer input: + >>> # Configure with an integer input: >>> _ = basis.set_input_shape(3) >>> basis.n_output_features 15 - - Configure with a tuple: + >>> # Configure with a tuple: >>> _ = basis.set_input_shape((4, 5)) >>> basis.n_output_features 100 - - Configure with an array: + >>> # Configure with an array: >>> x = np.ones((10, 4, 5)) >>> _ = basis.set_input_shape(x) >>> basis.n_output_features @@ -1223,7 +1202,7 @@ def split_by_feature( """ return super().split_by_feature(x, axis=axis) - @add_docstring("set_input_shape", RaisedCosineBasisLinear) + @add_docstring("set_input_shape", AtomicBasisMixin) def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples @@ -1231,18 +1210,15 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): >>> import nemos as nmo >>> import numpy as np >>> basis = nmo.basis.RaisedCosineLinearConv(5, 10) - - Configure with an integer input: + >>> # Configure with an integer input: >>> _ = basis.set_input_shape(3) >>> basis.n_output_features 15 - - Configure with a tuple: + >>> # Configure with a tuple: >>> _ = basis.set_input_shape((4, 5)) >>> basis.n_output_features 100 - - Configure with an array: + >>> # Configure with an array: >>> x = np.ones((10, 4, 5)) >>> _ = basis.set_input_shape(x) >>> basis.n_output_features @@ -1381,7 +1357,7 @@ def split_by_feature( """ return super().split_by_feature(x, axis=axis) - @add_docstring("set_input_shape", RaisedCosineBasisLog) + @add_docstring("set_input_shape", AtomicBasisMixin) def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples @@ -1389,18 +1365,15 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): >>> import nemos as nmo >>> import numpy as np >>> basis = nmo.basis.RaisedCosineLogEval(5) - - Configure with an integer input: + >>> # Configure with an integer input: >>> _ = basis.set_input_shape(3) >>> basis.n_output_features 15 - - Configure with a tuple: + >>> # Configure with a tuple: >>> _ = basis.set_input_shape((4, 5)) >>> basis.n_output_features 100 - - Configure with an array: + >>> # Configure with an array: >>> x = np.ones((10, 4, 5)) >>> _ = basis.set_input_shape(x) >>> basis.n_output_features @@ -1540,7 +1513,7 @@ def split_by_feature( """ return super().split_by_feature(x, axis=axis) - @add_docstring("set_input_shape", RaisedCosineBasisLog) + @add_docstring("set_input_shape", AtomicBasisMixin) def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples @@ -1548,18 +1521,15 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): >>> import nemos as nmo >>> import numpy as np >>> basis = nmo.basis.RaisedCosineLogConv(5, 10) - - Configure with an integer input: + >>> # Configure with an integer input: >>> _ = basis.set_input_shape(3) >>> basis.n_output_features 15 - - Configure with a tuple: + >>> # Configure with a tuple: >>> _ = basis.set_input_shape((4, 5)) >>> basis.n_output_features 100 - - Configure with an array: + >>> # Configure with an array: >>> x = np.ones((10, 4, 5)) >>> _ = basis.set_input_shape(x) >>> basis.n_output_features @@ -1680,7 +1650,7 @@ def split_by_feature( """ return super().split_by_feature(x, axis=axis) - @add_docstring("set_input_shape", OrthExponentialBasis) + @add_docstring("set_input_shape", AtomicBasisMixin) def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples @@ -1688,18 +1658,15 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): >>> import nemos as nmo >>> import numpy as np >>> basis = nmo.basis.OrthExponentialEval(5, decay_rates=np.arange(1, 6)) - - Configure with an integer input: + >>> # Configure with an integer input: >>> _ = basis.set_input_shape(3) >>> basis.n_output_features 15 - - Configure with a tuple: + >>> # Configure with a tuple: >>> _ = basis.set_input_shape((4, 5)) >>> basis.n_output_features 100 - - Configure with an array: + >>> # Configure with an array: >>> x = np.ones((10, 4, 5)) >>> _ = basis.set_input_shape(x) >>> basis.n_output_features @@ -1827,7 +1794,7 @@ def split_by_feature( """ return super().split_by_feature(x, axis=axis) - @add_docstring("set_input_shape", OrthExponentialBasis) + @add_docstring("set_input_shape", AtomicBasisMixin) def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ Examples @@ -1835,18 +1802,15 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): >>> import nemos as nmo >>> import numpy as np >>> basis = nmo.basis.OrthExponentialConv(5, window_size=10, decay_rates=np.arange(1, 6)) - - Configure with an integer input: + >>> # Configure with an integer input: >>> _ = basis.set_input_shape(3) >>> basis.n_output_features 15 - - Configure with a tuple: + >>> # Configure with a tuple: >>> _ = basis.set_input_shape((4, 5)) >>> basis.n_output_features 100 - - Configure with an array: + >>> # Configure with an array: >>> x = np.ones((10, 4, 5)) >>> _ = basis.set_input_shape(x) >>> basis.n_output_features From 3333eb72fee1bd363ff200310a93be73569c4d4f Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Dec 2024 21:53:39 -0500 Subject: [PATCH 23/41] added image logos for thumbnail --- docs/assets/nemos_sklearn.svg | 119 ++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 docs/assets/nemos_sklearn.svg diff --git a/docs/assets/nemos_sklearn.svg b/docs/assets/nemos_sklearn.svg new file mode 100644 index 00000000..8ea0a3e3 --- /dev/null +++ b/docs/assets/nemos_sklearn.svg @@ -0,0 +1,119 @@ + + + +scikit From 0c3fc32340acbbe37292328cd7d3237868c7597e Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Sun, 15 Dec 2024 12:23:09 -0500 Subject: [PATCH 24/41] use a generator for unpacking in a mem efficient way --- src/nemos/basis/_transformer_basis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 16c01e20..56c04a06 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -99,13 +99,13 @@ def _unpack_inputs(self, X: FeatureMatrix) -> List: """ n_samples = X.shape[0] - out = [ + out = ( np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape_)) for i, (bas, n_input) in enumerate( zip(self._list_components(), self._n_basis_input_) ) for cc in [sum(self._n_basis_input_[:i])] - ] + ) return out def fit(self, X: FeatureMatrix, y=None): From 5998a2d3f03d017cfdf0aabd94ae469ffb4f37f7 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Sun, 15 Dec 2024 12:23:53 -0500 Subject: [PATCH 25/41] fixed typing --- src/nemos/basis/_transformer_basis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 56c04a06..67063e28 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -1,7 +1,7 @@ from __future__ import annotations import copy -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, Generator import numpy as np @@ -81,7 +81,7 @@ def basis(self): def basis(self, basis): self._basis = basis - def _unpack_inputs(self, X: FeatureMatrix) -> List: + def _unpack_inputs(self, X: FeatureMatrix) -> Generator: """Unpack inputs. Unpack horizontally stacked inputs using slicing. This works gracefully with ``pynapple``, From 3d91df13062d928675adffc19cf71c5353920d59 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Sun, 15 Dec 2024 12:39:47 -0500 Subject: [PATCH 26/41] return generator for iterating over components --- src/nemos/basis/_basis_mixin.py | 24 ++++++++++++++---------- src/nemos/basis/_transformer_basis.py | 2 +- tests/test_basis.py | 11 ++++++----- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index d0c80f81..21ade27d 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -7,7 +7,8 @@ import inspect import warnings from functools import wraps -from typing import TYPE_CHECKING, Optional, Tuple, Union +from itertools import chain +from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union import numpy as np from numpy.typing import ArrayLike, NDArray @@ -88,8 +89,8 @@ def __sklearn_clone__(self) -> Basis: setattr(klass, attr_name, getattr(self, attr_name)) return klass - def _list_components(self): - """List all basis components. + def _iterate_over_components(self) -> Generator: + """Return a generator that iterates over all basis components. For atomic bases, the list is just [self]. @@ -98,7 +99,7 @@ def _list_components(self): A list with the basis components. """ - return [self] + return (x for x in [self]) def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ @@ -512,8 +513,8 @@ def __init__(self, basis1: Basis, basis2: Basis): self.basis2._parent = self shapes = ( - *(bas1._input_shape_ for bas1 in basis1._list_components()), - *(bas2._input_shape_ for bas2 in basis2._list_components()), + *(bas1._input_shape_ for bas1 in basis1._iterate_over_components()), + *(bas2._input_shape_ for bas2 in basis2._iterate_over_components()), ) # if all bases where set, then set input for composition. set_bases = (s is not None for s in shapes) @@ -599,17 +600,20 @@ def basis2(self): def basis2(self, bas: Basis): self._basis2 = bas - def _list_components(self): - """List all basis components. + def _iterate_over_components(self): + """Return a generator that iterates over all basis components. - Reimplements the default behavior by iteratively calling _list_components of the + Reimplements the default behavior by iteratively calling _iterate_over_components of the elements. Returns ------- A list with all 1d basis components. """ - return self._basis1._list_components() + self._basis2._list_components() + return chain( + self._basis1._iterate_over_components(), + self._basis2._iterate_over_components(), + ) @set_input_shape_state def __sklearn_clone__(self) -> Basis: diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 67063e28..4420eb40 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -102,7 +102,7 @@ def _unpack_inputs(self, X: FeatureMatrix) -> Generator: out = ( np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape_)) for i, (bas, n_input) in enumerate( - zip(self._list_components(), self._n_basis_input_) + zip(self._iterate_over_components(), self._n_basis_input_) ) for cc in [sum(self._n_basis_input_[:i])] ) diff --git a/tests/test_basis.py b/tests/test_basis.py index 770bce35..1031bc14 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -1331,14 +1331,14 @@ def test_set_input_value_types(self, inp_shape, expectation, cls): @pytest.mark.parametrize( "mode, kwargs", [("eval", {}), ("conv", {"window_size": 6})] ) - def test_list_component(self, mode, kwargs, cls): + def test_iterate_over_component(self, mode, kwargs, cls): basis_obj = cls[mode]( n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5), ) - out = basis_obj._list_components() + out = tuple(basis_obj._iterate_over_components()) assert len(out) == 1 assert id(out[0]) == id(basis_obj) @@ -1974,7 +1974,9 @@ class TestAdditiveBasis(CombinedBasis): @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) - def test_list_component(self, basis_a, basis_b, basis_class_specific_params): + def test_iterate_over_component( + self, basis_a, basis_b, basis_class_specific_params + ): basis_a_obj = self.instantiate_basis( 5, basis_a, basis_class_specific_params, window_size=10 ) @@ -1982,8 +1984,7 @@ def test_list_component(self, basis_a, basis_b, basis_class_specific_params): 6, basis_b, basis_class_specific_params, window_size=10 ) add = basis_a_obj + basis_b_obj - out = add._list_components() - + out = tuple(add._iterate_over_components()) assert len(out) == add._n_input_dimensionality def get_ids(bas): From 28599c8c6afa8557deb04f9aa565ee027cec6cf0 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Sun, 15 Dec 2024 12:50:47 -0500 Subject: [PATCH 27/41] change docstrings --- src/nemos/basis/_basis_mixin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 21ade27d..79c882c2 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -96,7 +96,7 @@ def _iterate_over_components(self) -> Generator: Returns ------- - A list with the basis components. + A generator returning self, it will be chained for composite bases. """ return (x for x in [self]) @@ -608,7 +608,7 @@ def _iterate_over_components(self): Returns ------- - A list with all 1d basis components. + A generator looping on each individual input. """ return chain( self._basis1._iterate_over_components(), From bb5363e1df5432583c685fd3f994e87e85645bb3 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Sun, 15 Dec 2024 12:50:58 -0500 Subject: [PATCH 28/41] change docstrings --- src/nemos/basis/_basis_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 79c882c2..8e212cbf 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -96,7 +96,7 @@ def _iterate_over_components(self) -> Generator: Returns ------- - A generator returning self, it will be chained for composite bases. + A generator returning self, it will be chained in composite bases. """ return (x for x in [self]) From 61a60918baa475a0c7a3fb64cd872ebc0143bad0 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 16 Dec 2024 14:43:37 -0500 Subject: [PATCH 29/41] add test and lint --- tests/test_basis.py | 46 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index 1031bc14..4a74623d 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -153,6 +153,15 @@ def method(self): pass assert CustomSubClass().method.__doc__ == "My extra text.\nMy custom method." + with pytest.raises(AttributeError, match="CustomClass has no attribute"): + + class CustomSubClass2(CustomClass): + @custom_add_docstring("unknown", cls=CustomClass) + def method(self): + """My custom method.""" + pass + + CustomSubClass2() @pytest.mark.parametrize( @@ -242,18 +251,34 @@ def test_expected_output_compute_features(basis_instance, super_class): ), OrthExponentialBasis, ), + ( + basis.OrthExponentialConv( + 10, decay_rates=np.arange(1, 11), window_size=12, label="a" + ) + * basis.RaisedCosineLogConv(10, window_size=11, label="b"), + OrthExponentialBasis, + ), + ( + basis.OrthExponentialConv( + 10, decay_rates=np.arange(1, 11), window_size=12, label="a" + ) + + basis.RaisedCosineLogConv(10, window_size=11, label="b"), + OrthExponentialBasis, + ), ], ) def test_expected_output_split_by_feature(basis_instance, super_class): - x = super_class.compute_features(basis_instance, np.linspace(0, 1, 100)) + inp = [np.linspace(0, 1, 100)] * basis_instance._n_input_dimensionality + x = super_class.compute_features(basis_instance, *inp) xdict = super_class.split_by_feature(basis_instance, x) xxdict = basis_instance.split_by_feature(x) assert xdict.keys() == xxdict.keys() - xx = xxdict["label"] - x = xdict["label"] - nans = np.isnan(x.sum(axis=(1, 2))) - assert np.all(np.isnan(xx[nans])) - np.testing.assert_array_equal(xx[~nans], x[~nans]) + for k in xdict.keys(): + xx = xxdict[k] + x = xdict[k] + nans = np.isnan(x.sum(axis=(1, 2))) + assert np.all(np.isnan(xx[nans])) + np.testing.assert_array_equal(xx[~nans], x[~nans]) @pytest.mark.parametrize( @@ -1580,6 +1605,15 @@ def test_minimum_number_of_basis_required_is_matched( n_basis_funcs=n_basis_funcs, order=order, **kwargs ) basis_obj.compute_features(np.linspace(0, 1, 10)) + + # test the setter valuerror + if (order > 1) & (n_basis_funcs > 1): + basis_obj = self.cls[mode](n_basis_funcs=20, order=order, **kwargs) + with pytest.raises( + ValueError, + match=rf"{self.cls[mode].__name__} `order` parameter cannot be larger than", + ): + basis_obj.n_basis_funcs = n_basis_funcs else: basis_obj = self.cls[mode]( n_basis_funcs=n_basis_funcs, order=order, **kwargs From 8bff762cc4e6065a24f5d97f2d6d6c082e92022b Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 16 Dec 2024 15:08:07 -0500 Subject: [PATCH 30/41] linted --- src/nemos/basis/_basis_mixin.py | 12 ++++------ tests/test_basis.py | 41 ++++++++++++++++++++++++++++++--- 2 files changed, 42 insertions(+), 11 deletions(-) diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 8e212cbf..44b024db 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -307,11 +307,7 @@ def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor): structure: a single (X, y) pair for the transformer, a number of time series for the Basis. """ - if self.kernel_ is None: - raise ValueError( - "You must call `setup_basis` before `_compute_features`! " - "Convolution kernel is not set." - ) + self._check_has_kernel() # before calling the convolve, check that the input matches # the expectation. We can check xi[0] only, since convolution # is applied at the end of the recursion on the 1D basis, ensuring len(xi) == 1. @@ -457,8 +453,8 @@ def _check_convolution_kwargs(conv_kwargs: dict): def _check_has_kernel(self) -> None: """Check that the kernel is pre-computed.""" if self.kernel_ is None: - raise ValueError( - "You must call `_set_kernel` before `_compute_features` for Conv basis." + raise RuntimeError( + "You must call `setup_basis` before `_compute_features` for Conv basis." ) @@ -517,7 +513,7 @@ def __init__(self, basis1: Basis, basis2: Basis): *(bas2._input_shape_ for bas2 in basis2._iterate_over_components()), ) # if all bases where set, then set input for composition. - set_bases = (s is not None for s in shapes) + set_bases = [s is not None for s in shapes] if all(set_bases): # pass down the input shapes diff --git a/tests/test_basis.py b/tests/test_basis.py index 4a74623d..e8759ace 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -1262,7 +1262,7 @@ def test_transform_fails(self, cls): n_basis_funcs=5, window_size=5, **extra_decay_rates(cls["conv"], 5) ) with pytest.raises( - ValueError, match="You must call `setup_basis` before `_compute_features`" + RuntimeError, match="You must call `setup_basis` before `_compute_features`" ): bas._compute_features(np.linspace(0, 1, 10)) @@ -2245,6 +2245,41 @@ def test_number_of_required_inputs_compute_features( with expectation: basis_obj.compute_features(*inputs) + @pytest.mark.parametrize("basis_a", list_all_basis_classes()) + @pytest.mark.parametrize("basis_b", list_all_basis_classes()) + @pytest.mark.parametrize("n_basis_a", [5]) + @pytest.mark.parametrize("n_basis_b", [6]) + @pytest.mark.parametrize("window_size", [10]) + def test_warn_partial_setup( + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, + ): + basis_a_obj = self.instantiate_basis( + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size + ) + basis_b_obj = self.instantiate_basis( + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size + ) + + basis_a_obj.set_input_shape(*([1] * basis_a_obj._n_input_dimensionality)) + with pytest.warns(UserWarning, match="Only some of the basis where"): + basis_a_obj + basis_b_obj + + # check that if both set addition is fine + basis_b_obj.set_input_shape(*([1] * basis_b_obj._n_input_dimensionality)) + basis_a_obj + basis_b_obj + + basis_a_obj = self.instantiate_basis( + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size + ) + with pytest.warns(UserWarning, match="Only some of the basis where"): + basis_a_obj + basis_b_obj + @pytest.mark.parametrize("sample_size", [11, 20]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @@ -2696,7 +2731,7 @@ def test_transform_fails( context = does_not_raise() else: context = pytest.raises( - ValueError, + RuntimeError, match="You must call `setup_basis` before `_compute_features`", ) with context: @@ -3624,7 +3659,7 @@ def test_transform_fails( context = does_not_raise() else: context = pytest.raises( - ValueError, + RuntimeError, match="You must call `setup_basis` before `_compute_features`", ) with context: From f0cb04049edb20a06717ff933d65b54deb7a2163 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 16 Dec 2024 15:51:11 -0500 Subject: [PATCH 31/41] added a test for window size after init for orth exp --- tests/test_basis.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_basis.py b/tests/test_basis.py index e8759ace..2cace735 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -1714,6 +1714,16 @@ def test_window_size_at_init(self, window_size, n_basis, expectation): with expectation: self.cls["conv"](n_basis, decay_rates=decay_rates, window_size=window_size) + def test_check_window_size_after_init(self): + decay_rates = np.asarray(np.arange(1, 5 + 1), dtype=float) + expectation = pytest.raises( + ValueError, + match="OrthExponentialConv basis requires at least a window_size", + ) + bas = self.cls["conv"](5, decay_rates=decay_rates, window_size=10) + with expectation: + bas.window_size = 4 + @pytest.mark.parametrize( "window_size, n_basis, expectation", [ From 090fa40761f8fd9b34774ccad486c11d06fba103 Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Wed, 18 Dec 2024 09:31:18 -0500 Subject: [PATCH 32/41] Update src/nemos/basis/_basis.py Co-authored-by: Sarah Jo Venditto --- src/nemos/basis/_basis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index b6a66a56..94e8b910 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -163,7 +163,7 @@ def n_output_features(self) -> int | None: ----- The number of output features can be determined only when the number of inputs provided to the basis is known. Therefore, before the first call to ``compute_features``, - this property will return ``None``. After that call, ``n_output_features`` will be available. + this property will return ``None``. After that call, or after setting the input shape with ``set_input_shape``, ``n_output_features`` will be available. """ if self._n_basis_input_ is not None: return self.n_basis_funcs * self._n_basis_input_[0] From 65c061889c9b16a2dc867bb01f3ee9e635ebf15f Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Wed, 18 Dec 2024 09:31:30 -0500 Subject: [PATCH 33/41] Update src/nemos/basis/_basis_mixin.py Co-authored-by: Sarah Jo Venditto --- src/nemos/basis/_basis_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 44b024db..c879d275 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -173,7 +173,7 @@ def _check_input_shape_consistency(self, x: NDArray): f" But got: {x.shape}.\n\n" "Note: The number of samples (`n_samples`) can vary between calls of `compute_features`, " "but all other dimensions must remain the same. If you need to process inputs with a " - "different shape, please create a new basis instance." + "different shape, please create a new basis instance, or set a new input shape by calling `set_input_shape`." ) From afdd3922d7d7bf342753ad48cd9021e2e70e179f Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 18 Dec 2024 09:34:36 -0500 Subject: [PATCH 34/41] fixes docs --- .../how_to_guide/plot_06_sklearn_pipeline_cv_demo.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md index 2afa68e8..1051833a 100644 --- a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md +++ b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md @@ -534,12 +534,12 @@ param_grid = dict( glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6), transformerbasis__n_basis_funcs=(3, 5, 10, 20, 100), transformerbasis___basis=( - nmo.basis.RaisedCosineLinearEval(5), - nmo.basis.RaisedCosineLinearEval(10), - nmo.basis.RaisedCosineLogEval(5), - nmo.basis.RaisedCosineLogEval(10), - nmo.basis.MSplineEval(5), - nmo.basis.MSplineEval(10), + nmo.basis.RaisedCosineLinearEval(5).set_input_shape(1), + nmo.basis.RaisedCosineLinearEval(10).set_input_shape(1), + nmo.basis.RaisedCosineLogEval(5).set_input_shape(1), + nmo.basis.RaisedCosineLogEval(10).set_input_shape(1), + nmo.basis.MSplineEval(5).set_input_shape(1), + nmo.basis.MSplineEval(10).set_input_shape(1), ), ) ``` From 2ade4bde411330c0685091cf01bb15fe6e8c88d6 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 18 Dec 2024 09:38:15 -0500 Subject: [PATCH 35/41] linted --- src/nemos/basis/_basis.py | 3 ++- src/nemos/basis/_basis_mixin.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 94e8b910..e2f4a762 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -163,7 +163,8 @@ def n_output_features(self) -> int | None: ----- The number of output features can be determined only when the number of inputs provided to the basis is known. Therefore, before the first call to ``compute_features``, - this property will return ``None``. After that call, or after setting the input shape with ``set_input_shape``, ``n_output_features`` will be available. + this property will return ``None``. After that call, or after setting the input shape with + ``set_input_shape``, ``n_output_features`` will be available. """ if self._n_basis_input_ is not None: return self.n_basis_funcs * self._n_basis_input_[0] diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index c879d275..b708ec3a 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -173,7 +173,8 @@ def _check_input_shape_consistency(self, x: NDArray): f" But got: {x.shape}.\n\n" "Note: The number of samples (`n_samples`) can vary between calls of `compute_features`, " "but all other dimensions must remain the same. If you need to process inputs with a " - "different shape, please create a new basis instance, or set a new input shape by calling `set_input_shape`." + "different shape, please create a new basis instance, or set a new input shape by calling " + "`set_input_shape`." ) From cbf014b799a7dd3d944ae3da0d0ef1b03caea500 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 18 Dec 2024 12:53:04 -0500 Subject: [PATCH 36/41] add ignore links --- .readthedocs.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 37e28a3a..5eca7438 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -15,7 +15,7 @@ build: - gem install html-proofer -v ">= 5.0.9" # Ensure version >= 5.0.9 post_build: # Check everything except 403s and a jneurosci, which returns 404 but the link works when clicking. - - htmlproofer $READTHEDOCS_OUTPUT/html --checks Links,Scripts,Images --ignore-urls "https://fonts.gstatic.com,https://celltypes.brain-map.org/experiment/electrophysiology/478498617,https://www.jneurosci.org/content/25/47/11003" --assume-extension --check-external-hash --ignore-status-codes 403,0 --ignore-files "/.+\/_static\/.+/","/.+\/stubs\/.+/","/.+\/tutorials/plot_02_head_direction.+/" + - htmlproofer $READTHEDOCS_OUTPUT/html --checks Links,Scripts,Images --ignore-urls "https://fonts.gstatic.com,https://celltypes.brain-map.org/experiment/electrophysiology/478498617,https://www.jneurosci.org/content/25/47/11003,https://www.nature.com/articles/s41467-017-01908-3,https://doi.org/10.1038/s41467-017-01908-3" --assume-extension --check-external-hash --ignore-status-codes 403,0 --ignore-files "/.+\/_static\/.+/","/.+\/stubs\/.+/","/.+\/tutorials/plot_02_head_direction.+/" # The auto-generated animation doesn't have a alt or src/srcset; I am able to ignore missing alt, but I cannot work around a missing src/srcset # therefore for this file I am not checking the figures. - htmlproofer $READTHEDOCS_OUTPUT/html/tutorials/plot_02_head_direction.html --checks Links,Scripts --ignore-urls "https://www.jneurosci.org/content/25/47/11003" From 0038600875b056558c142c5d778e03290cd4f835 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 18 Dec 2024 14:42:51 -0500 Subject: [PATCH 37/41] first round of edits --- .../how_to_guide/plot_05_transformer_basis.md | 71 +++++++++---------- .../plot_06_sklearn_pipeline_cv_demo.md | 4 ++ 2 files changed, 38 insertions(+), 37 deletions(-) diff --git a/docs/how_to_guide/plot_05_transformer_basis.md b/docs/how_to_guide/plot_05_transformer_basis.md index 9b8fe6b6..aed91e54 100644 --- a/docs/how_to_guide/plot_05_transformer_basis.md +++ b/docs/how_to_guide/plot_05_transformer_basis.md @@ -11,7 +11,7 @@ kernelspec: name: python3 --- -# Converting NeMoS Bases To scikit-learn Transformers +# Using bases as scikit-learn transformers (tansformer-vs-nemos-basis)= ## scikit-learn Transformers and NeMoS Basis @@ -19,22 +19,26 @@ kernelspec: `scikit-learn` is a powerful machine learning library that provides advanced tools for creating data analysis pipelines, from input transformations to model fitting and cross-validation. All of `scikit-learn`'s machinery relies on strict assumptions about input structure. In particular, all `scikit-learn` -objects require inputs as arrays of at most two dimensions, where the first dimension represents the time (or samples) +objects require inputs to be arrays of at most two dimensions, where the first dimension represents the time (or samples) axis, and the second dimension represents features. While this may feel rigid, it enables transformations to be seamlessly chained together, greatly simplifying the process of building stable, complex pipelines. -On the other hand, `NeMoS` takes a different approach to feature construction. `NeMoS`' bases are composable constructors that allow for more flexibility in the required input structure. -Depending on the basis type, it can accept one or more input arrays or `pynapple` time series data, each of which can take any shape as long as the time (or sample) axis is the first of each array; -`NeMoS` design favours object composability: one can combine any two or more bases to compute complex features, with a user-friendly interface that can accept a separate array/time series for each input type (e.g., an array with the spike counts, an array for the animal's position, etc.). +They can accept arrays or `pynapple` time series data, which can take any shape as long as the time (or sample) axis is the first of each array. +Furthermore, `NeMoS` design favours object composability: one can combine bases into [`CompositeBasis`](composing_basis_function) objects to compute complex features, with a user-friendly interface that can accept a separate array/time series for each input type (e.g., an array with the spike counts, an array for the animal's position, etc.). Both approaches to data transformation are valuable and each has its own advantages. Wouldn't it be great if one could combine the two? Well, this is what NeMoS `TransformerBasis` is for! ## From Basis to TransformerBasis +:::{admonition} Composite Basis +:class: note -With NeMoS, you can easily create a basis accepting two inputs. Let's assume that we want to process neural activity stored in a 2-dimensional spike count array of shape `(n_samples, n_neurons)` and a second array containing the speed of an animal, with shape `(n_samples,)`. +To learn more on composite basis, take a look at [this note](composing_basis_function). +::: + +With NeMoS, you can easily create a basis which accepts two inputs. Let's assume that we want to process neural activity stored in a 2-dimensional spike count array of shape `(n_samples, n_neurons)` and a second array containing the speed of an animal, with shape `(n_samples,)`. ```{code-cell} ipython3 import numpy as np @@ -58,7 +62,7 @@ X = composite_basis.compute_features(counts, speed) ### Converting NeMoS `Basis` to a transformer Now, imagine that we want to use this basis as a step in a `scikit-learn` pipeline. -In this standard (for NeMoS) form, it would not be possible as the `composite_basis` object requires two inputs. We need to convert it first into a compliant `scikit-learn` transformer. This can be achieved through the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) wrapper class. +In this standard (for NeMoS) form, it would not be possible as the `composite_basis` object requires two inputs. We need to convert it first into a `scikit-learn`-compliant transformer. This can be achieved through the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) wrapper class. Instantiating a [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) can be done either by using the constructor directly or with [`Basis.to_transformer()`](nemos.basis._basis.Basis.to_transformer): @@ -98,55 +102,45 @@ As with any `sckit-learn` transformer, the `TransformerBasis` implements `fit`, At this point we have an object equipped with the correct methods, so now, all we have to do is concatenate the inputs into a unique array and call `fit_transform`, right? ```{code-cell} ipython3 +:tags: [raises-exception] # reinstantiate the basis transformer for illustration porpuses composite_basis = counts_basis + speed_basis trans_bas = (composite_basis).to_transformer() - # concatenate the inputs inp = np.concatenate([counts, speed[:, np.newaxis]], axis=1) print(inp.shape) - -try: - trans_bas.fit_transform(inp) -except RuntimeError as e: - print(repr(e)) +trans_bas.fit_transform(inp) ``` -...Unfortunately, not yet. The problem is that the basis has never interacted with the two separate inputs, and therefore doesn't know which columns of `inp` should be processed by `count_basis` and which by `speed_basis`. +...Unfortunately, not yet. The problem is that the basis doesn't know which columns of `inp` should be processed by `count_basis` and which by `speed_basis`. You can provide this information by calling the `set_input_shape` method of the basis. -This can be called before or after the transformer basis is defined. The method extracts and stores the array shapes excluding the sample axis (which won't be affected in the concatenation). - -`set_input_shape` directly accepts the inputs: - -```{code-cell} ipython3 - -composite_basis.set_input_shape(counts, speed) -out = composite_basis.to_transformer().fit_transform(inp) -``` +This can be called before or after the transformer basis is defined. The method extracts and stores the number of columns for each input. There are multiple ways to call this method: -If the input is 1D or 2D, it also accepts the number of columns: -```{code-cell} ipython3 +- It directly accepts the input: `composite_basis.set_input_shape(counts, speed)`. +- If the input is 1D or 2D, it also accepts the number of columns: `composite_basis.set_input_shape(5, 1)`. +- A tuple containing the shapes of all except the first: `composite_basis.set_input_shape((5,), (1,))`. +- A mix of the above methods: `composite_basis.set_input_shape(counts, 1)`. -composite_basis.set_input_shape(5, 1) -out = composite_basis.to_transformer().fit_transform(inp) -``` +:::{note} -A tuple containing the shapes of all the axes other than the first, -```{code-cell} ipython3 +Note that what `set_input_shapes` requires are the dimensions of the input stimuli, with the exception of the sample +axis. For example, if the input is a 4D tensor, one needs to provide the last 3 dimensions: -composite_basis.set_input_shape((5,), (1,)) -out = composite_basis.to_transformer().fit_transform(inp) -``` +```{code} ipython3 +# generate a 4D input +x = np.random.randn(10, 3, 2, 1) -Or a mix of the above. -```{code-cell} ipython3 +# define and setup the basis +basis = nmo.basis.BSplineEval(5).set_input_shape((3, 2, 1)) -composite_basis.set_input_shape(counts, 1) -out = composite_basis.to_transformer().fit_transform(inp) +X = basis.to_transformer().transform( + x.reshape(10, -1) # reshape to 2D +) ``` +::: You can also invert the order of operations and call `to_transform` first and then set the input shapes. ```{code-cell} ipython3 @@ -162,8 +156,11 @@ If you define a basis and call `compute_features` on your inputs, internally, it and the `TransformerBasis` will be ready to process without any direct call to `set_input_shape`. ::: +:::{warning} + If for some reason you need to provide an input of different shape to an already set-up transformer, you must reset the `TransformerBasis` with `set_input_shape`. +::: ```{code-cell} ipython3 diff --git a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md index 1051833a..04d1092b 100644 --- a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md +++ b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md @@ -529,7 +529,11 @@ The plot confirms that the firing rate distribution is accurately captured by ou :::{warning} Please note that because it would lead to unexpected behavior, mixing the two ways of defining values for the parameter grid is not allowed. The following would lead to an error: + + + ```{code} ipython + param_grid = dict( glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6), transformerbasis__n_basis_funcs=(3, 5, 10, 20, 100), From efc138fb2182ac3c229eb044ae28737c3bf6038d Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 18 Dec 2024 15:34:43 -0500 Subject: [PATCH 38/41] fixes from PR --- src/nemos/basis/_basis.py | 11 +++++------ src/nemos/basis/_basis_mixin.py | 17 +++++----------- src/nemos/basis/basis.py | 16 --------------- tests/test_basis.py | 35 --------------------------------- 4 files changed, 10 insertions(+), 69 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index e2f4a762..06af0f9f 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -133,10 +133,9 @@ class Basis(Base, abc.ABC, BasisTransformerMixin): def __init__( self, - mode: Literal["eval", "conv"] = "eval", + mode: Literal["eval", "conv", "composite"] = "eval", label: Optional[str] = None, ) -> None: - self._n_basis_funcs = getattr(self, "_n_basis_funcs", None) self._n_input_dimensionality = getattr(self, "_n_input_dimensionality", 0) self._mode = mode @@ -147,8 +146,8 @@ def __init__( self._label = str(label) # specified only after inputs/input shapes are provided - self._n_basis_input_ = getattr(self, "_n_basis_input_", None) - self._input_shape_ = getattr(self, "_input_shape_", None) + self._n_basis_input_ = None + self._input_shape_ = None # initialize parent to None. This should not end in "_" because it is # a permanent property of a basis, defined at composite basis init @@ -743,7 +742,7 @@ class AdditiveBasis(CompositeBasisMixin, Basis): def __init__(self, basis1: Basis, basis2: Basis) -> None: CompositeBasisMixin.__init__(self, basis1, basis2) - Basis.__init__(self, mode="eval") + Basis.__init__(self, mode="composite") self._label = "(" + basis1.label + " + " + basis2.label + ")" self._n_input_dimensionality = ( @@ -1154,7 +1153,7 @@ class MultiplicativeBasis(CompositeBasisMixin, Basis): def __init__(self, basis1: Basis, basis2: Basis) -> None: CompositeBasisMixin.__init__(self, basis1, basis2) - Basis.__init__(self, mode="eval") + Basis.__init__(self, mode="composite") self._label = "(" + basis1.label + " * " + basis2.label + ")" self._n_input_dimensionality = ( basis1._n_input_dimensionality + basis2._n_input_dimensionality diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index b708ec3a..3b40b4b6 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -5,7 +5,6 @@ import abc import copy import inspect -import warnings from functools import wraps from itertools import chain from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union @@ -21,7 +20,9 @@ from ._basis import Basis -def set_input_shape_state(method): +def set_input_shape_state( + method, states: Tuple[str] = ("_n_basis_input_", "_input_shape_") +): """ Decorator to preserve input shape-related attributes during method execution. @@ -36,6 +37,7 @@ def set_input_shape_state(method): method : The method to be wrapped. This method is expected to return an object (`klass`) that requires the `_n_basis_input_` and `_input_shape_` attributes. + attr_list Returns ------- @@ -60,7 +62,7 @@ def set_input_shape_state(method): @wraps(method) def wrapper(self, *args, **kwargs): klass: Basis = method(self, *args, **kwargs) - for attr_name in ["_n_basis_input_", "_input_shape_"]: + for attr_name in states: setattr(klass, attr_name, getattr(self, attr_name)) return klass @@ -84,9 +86,6 @@ def __sklearn_clone__(self) -> Basis: cross-validation unusable. """ klass = self.__class__(**self.get_params()) - - for attr_name in ["_n_basis_input_", "_input_shape_"]: - setattr(klass, attr_name, getattr(self, attr_name)) return klass def _iterate_over_components(self) -> Generator: @@ -519,12 +518,6 @@ def __init__(self, basis1: Basis, basis2: Basis): if all(set_bases): # pass down the input shapes self.set_input_shape(*shapes) - elif any(set_bases): - warnings.warn( - "Only some of the basis where initialized with `set_input_shape`, " - "please initialize the composite basis before computing features.", - category=UserWarning, - ) @property @abc.abstractmethod diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 2ea0bd86..8601a101 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -1830,19 +1830,3 @@ def _check_window_size(self, window_size: int): f"of basis functions. window_size is {window_size}, n_basis_funcs while" f"is {self.n_basis_funcs}." ) - - def set_kernel(self): - try: - super().set_kernel() - except ValueError as e: - if "OrthExponentialBasis requires at least as many" in str(e): - raise ValueError( - "Cannot set the kernels for OrthExponentialBasis when `window_size` is smaller " - "than `n_basis_funcs.\n" - "Please, increase the window size or reduce the number of basis functions. " - f"Current `window_size` is {self.window_size}, while `n_basis_funcs` is " - f"{self.n_basis_funcs}." - ) - else: - raise e - return self diff --git a/tests/test_basis.py b/tests/test_basis.py index 2cace735..e29b6cdd 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -2255,41 +2255,6 @@ def test_number_of_required_inputs_compute_features( with expectation: basis_obj.compute_features(*inputs) - @pytest.mark.parametrize("basis_a", list_all_basis_classes()) - @pytest.mark.parametrize("basis_b", list_all_basis_classes()) - @pytest.mark.parametrize("n_basis_a", [5]) - @pytest.mark.parametrize("n_basis_b", [6]) - @pytest.mark.parametrize("window_size", [10]) - def test_warn_partial_setup( - self, - n_basis_a, - n_basis_b, - basis_a, - basis_b, - window_size, - basis_class_specific_params, - ): - basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, basis_class_specific_params, window_size=window_size - ) - basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, basis_class_specific_params, window_size=window_size - ) - - basis_a_obj.set_input_shape(*([1] * basis_a_obj._n_input_dimensionality)) - with pytest.warns(UserWarning, match="Only some of the basis where"): - basis_a_obj + basis_b_obj - - # check that if both set addition is fine - basis_b_obj.set_input_shape(*([1] * basis_b_obj._n_input_dimensionality)) - basis_a_obj + basis_b_obj - - basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, basis_class_specific_params, window_size=window_size - ) - with pytest.warns(UserWarning, match="Only some of the basis where"): - basis_a_obj + basis_b_obj - @pytest.mark.parametrize("sample_size", [11, 20]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) From 734ea9d7771e9d71321159233713c95cf57c5264 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 18 Dec 2024 16:00:05 -0500 Subject: [PATCH 39/41] removed properties --- src/nemos/basis/_basis.py | 36 ++++++++++++++--------------- src/nemos/basis/_basis_mixin.py | 40 ++++++++++----------------------- tests/test_basis.py | 22 ++++++++---------- 3 files changed, 39 insertions(+), 59 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 06af0f9f..6e9cbbca 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -761,8 +761,8 @@ def n_basis_funcs(self): @property def n_output_features(self): - out1 = getattr(self._basis1, "n_output_features", None) - out2 = getattr(self._basis2, "n_output_features", None) + out1 = getattr(self.basis1, "n_output_features", None) + out2 = getattr(self.basis2, "n_output_features", None) if out1 is None or out2 is None: return None return out1 + out2 @@ -829,8 +829,8 @@ def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatri """ X = np.hstack( ( - self._basis1._evaluate(*xi[: self._basis1._n_input_dimensionality]), - self._basis2._evaluate(*xi[self._basis1._n_input_dimensionality :]), + self.basis1._evaluate(*xi[: self.basis1._n_input_dimensionality]), + self.basis2._evaluate(*xi[self.basis1._n_input_dimensionality :]), ) ) return X @@ -878,11 +878,11 @@ def _compute_features( hstack_pynapple = support_pynapple(conv_type="numpy")(np.hstack) X = hstack_pynapple( ( - self._basis1._compute_features( - *xi[: self._basis1._n_input_dimensionality] + self.basis1._compute_features( + *xi[: self.basis1._n_input_dimensionality] ), - self._basis2._compute_features( - *xi[self._basis1._n_input_dimensionality :] + self.basis2._compute_features( + *xi[self.basis1._n_input_dimensionality :] ), ), ) @@ -1098,13 +1098,13 @@ def _get_feature_slicing( # If the instance is of AdditiveBasis type, handle slicing for the additive components - split_dict, start_slice = self._basis1._get_feature_slicing( - n_inputs[: len(self._basis1._n_basis_input_)], + split_dict, start_slice = self.basis1._get_feature_slicing( + n_inputs[: len(self.basis1._n_basis_input_)], start_slice, split_by_input=split_by_input, ) - sp2, start_slice = self._basis2._get_feature_slicing( - n_inputs[len(self._basis1._n_basis_input_) :], + sp2, start_slice = self.basis2._get_feature_slicing( + n_inputs[len(self.basis1._n_basis_input_) :], start_slice, split_by_input=split_by_input, ) @@ -1171,8 +1171,8 @@ def n_basis_funcs(self): @property def n_output_features(self): - out1 = getattr(self._basis1, "n_output_features", None) - out2 = getattr(self._basis2, "n_output_features", None) + out1 = getattr(self.basis1, "n_output_features", None) + out2 = getattr(self.basis2, "n_output_features", None) if out1 is None or out2 is None: return None return out1 * out2 @@ -1205,8 +1205,8 @@ def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatri """ X = np.asarray( row_wise_kron( - self._basis1._evaluate(*xi[: self._basis1._n_input_dimensionality]), - self._basis2._evaluate(*xi[self._basis1._n_input_dimensionality :]), + self.basis1._evaluate(*xi[: self.basis1._n_input_dimensionality]), + self.basis2._evaluate(*xi[self.basis1._n_input_dimensionality :]), transpose=False, ) ) @@ -1239,8 +1239,8 @@ def _compute_features( """ kron = support_pynapple(conv_type="numpy")(row_wise_kron) X = kron( - self._basis1._compute_features(*xi[: self._basis1._n_input_dimensionality]), - self._basis2._compute_features(*xi[self._basis1._n_input_dimensionality :]), + self.basis1._compute_features(*xi[: self.basis1._n_input_dimensionality]), + self.basis2._compute_features(*xi[self.basis1._n_input_dimensionality :]), transpose=False, ) return X diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 3b40b4b6..c65710bd 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -567,29 +567,13 @@ def _set_input_independent_states(self): def _check_input_shape_consistency(self, *xi: NDArray): """Check the input shape consistency for all basis elements.""" - self._basis1._check_input_shape_consistency( - *xi[: self._basis1._n_input_dimensionality] + self.basis1._check_input_shape_consistency( + *xi[: self.basis1._n_input_dimensionality] ) - self._basis2._check_input_shape_consistency( - *xi[self._basis1._n_input_dimensionality :] + self.basis2._check_input_shape_consistency( + *xi[self.basis1._n_input_dimensionality :] ) - @property - def basis1(self): - return self._basis1 - - @basis1.setter - def basis1(self, bas: Basis): - self._basis1 = bas - - @property - def basis2(self): - return self._basis2 - - @basis2.setter - def basis2(self, bas: Basis): - self._basis2 = bas - def _iterate_over_components(self): """Return a generator that iterates over all basis components. @@ -601,8 +585,8 @@ def _iterate_over_components(self): A generator looping on each individual input. """ return chain( - self._basis1._iterate_over_components(), - self._basis2._iterate_over_components(), + self.basis1._iterate_over_components(), + self.basis2._iterate_over_components(), ) @set_input_shape_state @@ -616,8 +600,8 @@ def __sklearn_clone__(self) -> Basis: The method also handles recursive cloning for composite basis structures. """ # clone recursively - basis1 = self._basis1.__sklearn_clone__() - basis2 = self._basis2.__sklearn_clone__() + basis1 = self.basis1.__sklearn_clone__() + basis2 = self.basis2.__sklearn_clone__() klass = self.__class__(basis1, basis2) for attr_name in ["_n_basis_input_", "_input_shape_"]: @@ -653,11 +637,11 @@ def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: Returns the instance itself to allow method chaining. """ self._n_basis_input_ = ( - *self._basis1.set_input_shape( - *xi[: self._basis1._n_input_dimensionality] + *self.basis1.set_input_shape( + *xi[: self.basis1._n_input_dimensionality] )._n_basis_input_, - *self._basis2.set_input_shape( - *xi[self._basis1._n_input_dimensionality :] + *self.basis2.set_input_shape( + *xi[self.basis1._n_input_dimensionality :] )._n_basis_input_, ) return self diff --git a/tests/test_basis.py b/tests/test_basis.py index e29b6cdd..8daeb683 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -2120,13 +2120,9 @@ def compare(b1, b2): if hasattr(b1, "basis1"): compare(b1.basis1, b2.basis1) compare(b1.basis2, b2.basis2) - # add all params that are not parent or _basis1,_basis2 - d1 = filter_attributes( - b1, exclude_keys=["_basis1", "_basis2", "_parent"] - ) - d2 = filter_attributes( - b2, exclude_keys=["_basis1", "_basis2", "_parent"] - ) + # add all params that are not parent or basis1,basis2 + d1 = filter_attributes(b1, exclude_keys=["basis1", "basis2", "_parent"]) + d2 = filter_attributes(b2, exclude_keys=["basis1", "basis2", "_parent"]) assert d1 == d2 else: decay_rates_b1 = b1.__dict__.get("_decay_rates", -1) @@ -2677,9 +2673,9 @@ def test_fit_kernel( def check_kernel(basis_obj): has_kern = [] - if hasattr(basis_obj, "_basis1"): - has_kern += check_kernel(basis_obj._basis1) - has_kern += check_kernel(basis_obj._basis2) + if hasattr(basis_obj, "basis1"): + has_kern += check_kernel(basis_obj.basis1) + has_kern += check_kernel(basis_obj.basis2) else: has_kern += [ basis_obj.kernel_ is not None if basis_obj.mode == "conv" else True @@ -3605,9 +3601,9 @@ def test_fit_kernel( def check_kernel(basis_obj): has_kern = [] - if hasattr(basis_obj, "_basis1"): - has_kern += check_kernel(basis_obj._basis1) - has_kern += check_kernel(basis_obj._basis2) + if hasattr(basis_obj, "basis1"): + has_kern += check_kernel(basis_obj.basis1) + has_kern += check_kernel(basis_obj.basis2) else: has_kern += [ basis_obj.kernel_ is not None if basis_obj.mode == "conv" else True From 3f6f89d32b07f49df052084ad91923fe4511105c Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 18 Dec 2024 16:24:55 -0500 Subject: [PATCH 40/41] roll-back exception handling in tutorial --- docs/how_to_guide/plot_05_transformer_basis.md | 8 ++++++-- docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/how_to_guide/plot_05_transformer_basis.md b/docs/how_to_guide/plot_05_transformer_basis.md index aed91e54..5cdd10ee 100644 --- a/docs/how_to_guide/plot_05_transformer_basis.md +++ b/docs/how_to_guide/plot_05_transformer_basis.md @@ -102,7 +102,6 @@ As with any `sckit-learn` transformer, the `TransformerBasis` implements `fit`, At this point we have an object equipped with the correct methods, so now, all we have to do is concatenate the inputs into a unique array and call `fit_transform`, right? ```{code-cell} ipython3 -:tags: [raises-exception] # reinstantiate the basis transformer for illustration porpuses composite_basis = counts_basis + speed_basis @@ -110,7 +109,12 @@ trans_bas = (composite_basis).to_transformer() # concatenate the inputs inp = np.concatenate([counts, speed[:, np.newaxis]], axis=1) print(inp.shape) -trans_bas.fit_transform(inp) + +try: + trans_bas.fit_transform(inp) +except RuntimeError as e: + print(repr(e)) + ``` ...Unfortunately, not yet. The problem is that the basis doesn't know which columns of `inp` should be processed by `count_basis` and which by `speed_basis`. diff --git a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md index 04d1092b..4073b928 100644 --- a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md +++ b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md @@ -169,7 +169,7 @@ trans_bas.set_input_shape(1) ``` :::{admonition} Learn More about `TransformerBasis` -:note: +:class: note To learn more about `sklearn` transformers and `TransforerBasis`, check out [this note](tansformer-vs-nemos-basis). ::: From a510ef3ca4f7c74ec7ad425cdd3093d15a813ba3 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 18 Dec 2024 16:56:25 -0500 Subject: [PATCH 41/41] improved layout how tos --- docs/how_to_guide/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/how_to_guide/README.md b/docs/how_to_guide/README.md index 432eb800..95094a5c 100644 --- a/docs/how_to_guide/README.md +++ b/docs/how_to_guide/README.md @@ -14,7 +14,7 @@ pip install nemos[examples] ::: -::::{grid} 1 2 3 4 +::::{grid} 1 2 3 3 :::{grid-item-card}