From ce5dff2bd9f2df1817b051e331e2b33c590ce630 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 12 Dec 2024 23:54:18 -0500 Subject: [PATCH] re-struct bases --- pyproject.toml | 5 + src/nemos/basis/_basis.py | 382 +++--- src/nemos/basis/_basis_mixin.py | 198 ++- src/nemos/basis/_decaying_exponential.py | 4 - src/nemos/basis/_raised_cosine_basis.py | 6 - src/nemos/basis/_spline_basis.py | 17 - src/nemos/basis/_transformer_basis.py | 51 +- src/nemos/basis/basis.py | 97 +- tests/conftest.py | 103 ++ tests/test_basis.py | 1426 +++++++++++----------- tests/test_pipeline.py | 17 +- 11 files changed, 1310 insertions(+), 996 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4c6134ef..d20fd307 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,6 +123,11 @@ profile = "black" # Configure pytest [tool.pytest.ini_options] testpaths = ["tests"] # Specify the directory where test files are located +filterwarnings = [ + # note the use of single quote below to denote "raw" strings in TOML + 'ignore:plotting functions contained within:UserWarning', + 'ignore:Tolerance of \d\.\d+e-\d\d reached:RuntimeWarning', +] [tool.coverage.run] omit = [ diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index d57b2ad2..adaa2c8a 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -53,7 +53,7 @@ def check_one_dimensional(func: Callable) -> Callable: """Check if the input is one-dimensional.""" @wraps(func) - def wrapper(self: Basis, *xi: ArrayLike, **kwargs): + def wrapper(self: Basis, *xi: NDArray, **kwargs): if any(x.ndim != 1 for x in xi): raise ValueError("Input sample must be one dimensional!") return func(self, *xi, **kwargs) @@ -135,28 +135,28 @@ class Basis(Base, abc.ABC, BasisTransformerMixin): def __init__( self, - n_basis_funcs: int, mode: Literal["eval", "conv"] = "eval", label: Optional[str] = None, ) -> None: - self.n_basis_funcs = n_basis_funcs - self._n_input_dimensionality = 0 + self._n_basis_funcs = getattr(self, "_n_basis_funcs", None) + self._n_input_dimensionality = getattr(self, "_n_input_dimensionality", 0) self._mode = mode - self._n_basis_input = None - - # these parameters are going to be set at the first call of `compute_features` - # since we cannot know a-priori how many features may be convolved - self._n_output_features = None - self._input_shape = None - if label is None: self._label = self.__class__.__name__ else: self._label = str(label) - self.kernel_ = None + self._check_n_basis_min() + + # specified only after inputs/input shapes are provided + self._n_basis_input_ = getattr(self, "_n_basis_input_", None) + self._input_shape_ = getattr(self, "_input_shape_", None) + + # initialize parent to None. This should not end in "_" because it is + # a permanent property of a basis, defined at composite basis init + self._parent = None @property def n_output_features(self) -> int | None: @@ -169,7 +169,9 @@ def n_output_features(self) -> int | None: provided to the basis is known. Therefore, before the first call to ``compute_features``, this property will return ``None``. After that call, ``n_output_features`` will be available. """ - return self._n_output_features + if self._n_basis_input_ is not None: + return self.n_basis_funcs * self._n_basis_input_[0] + return None @property def label(self) -> str: @@ -177,12 +179,12 @@ def label(self) -> str: return self._label @property - def n_basis_input(self) -> tuple | None: + def n_basis_input_(self) -> tuple | None: """Number of expected inputs. The number of inputs ``compute_feature`` expects. """ - return self._n_basis_input + return self._n_basis_input_ @property def n_basis_funcs(self): @@ -204,43 +206,6 @@ def mode(self): """Mode of operation, either ``"conv"`` or ``"eval"``.""" return self._mode - @staticmethod - def _apply_identifiability_constraints(X: NDArray): - """Apply identifiability constraints to a design matrix `X`. - - Removes columns from `X` until `[1, X]` is full rank to ensure the uniqueness - of the GLM (Generalized Linear Model) maximum-likelihood solution. This is particularly - crucial for models using bases like BSplines and CyclicBspline, which, due to their - construction, sum to 1 and can cause rank deficiency when combined with an intercept. - - For GLMs, this rank deficiency means that different sets of coefficients might yield - identical predicted rates and log-likelihood, complicating parameter learning, especially - in the absence of regularization. - - Parameters - ---------- - X: - The design matrix before applying the identifiability constraints. - - Returns - ------- - : - The adjusted design matrix with redundant columns dropped and columns mean-centered. - """ - - def add_constant(x): - return np.hstack((np.ones((x.shape[0], 1)), x)) - - rank = np.linalg.matrix_rank(add_constant(X)) - # mean center - X = X - np.nanmean(X, axis=0) - while rank < X.shape[1] + 1: - # drop a column - X = X[:, :-1] - # recompute rank - rank = np.linalg.matrix_rank(add_constant(X)) - return X - @check_transform_input def compute_features( self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor @@ -273,24 +238,112 @@ def compute_features( Subclasses should implement how to handle the transformation specific to their basis function types and operation modes. """ - if self._n_basis_input is None: + if self._n_basis_input_ is None: self.set_input_shape(*xi) self._check_input_shape_consistency(*xi) - self.set_kernel() + self._set_input_independent_states() return self._compute_features(*xi) @abc.abstractmethod def _compute_features( self, *xi: NDArray | Tsd | TsdFrame | TsdTensor ) -> FeatureMatrix: - """Convolve or evaluate the basis.""" + """Convolve or evaluate the basis. + + This method is intended to be equivalent to the sklearn transformer ``transform`` method. + As the latter, it computes the transformation assuming that all the states are already + pre-computed by ``_fit_basis``, a method corresponding to ``fit``. + + The method differs from transformer's ``transform`` for the structure of the input that it accepts. + In particular, ``_compute_features`` accepts a number of different time series, one per 1D basis component, + while ``transform`` requires all inputs to be concatenated in a single array. + """ + pass + + @abc.abstractmethod + def setup_basis(self, *xi: ArrayLike) -> FeatureMatrix: + """Pre-compute all basis state variables. + + This method is intended to be equivalent to the sklearn transformer ``fit`` method. + As the latter, it computes all the state attributes, and store it with the convention + that the attribute name **must** end with "_", for example ``self.kernel_``, + ``self._input_shape_``. + + The method differs from transformer's ``fit`` for the structure of the input that it accepts. + In particular, ``_fit_basis`` accepts a number of different time series, one per 1D basis component, + while ``fit`` requires all inputs to be concatenated in a single array. + """ pass @abc.abstractmethod - def set_kernel(self): - """Set kernel for conv basis and return self or just return self for eval.""" + def _set_input_independent_states(self): + """ + Compute all the basis states that do not depend on the input. + + An example of such state is the kernel_ for Conv baisis, which can be computed + without any input (it only depends on the basis type, the window size and the + number of basis elements). + """ pass + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): + """ + Set the expected input shape for the basis object. + + This method configures the shape of the input data that the basis object expects. + ``xi`` can be specified as an integer, a tuple of integers, or derived + from an array. The method also calculates the total number of input + features and output features based on the number of basis functions. + + Parameters + ---------- + xi : + The input shape specification. + - An integer: Represents the dimensionality of the input. A value of ``1`` is treated as scalar input. + - A tuple: Represents the exact input shape excluding the first axis (sample axis). + All elements must be integers. + - An array: The shape is extracted, excluding the first axis (assumed to be the sample axis). + + Raises + ------ + ValueError + If a tuple is provided and it contains non-integer elements. + + Returns + ------- + self : + Returns the instance itself to allow method chaining. + + Notes + ----- + All state attributes that depends on the input must be set in this method in order for + the API of basis to work correctly. In particular, this method is called by ``_basis_fit``, + which is equivalent to ``fit`` for a transformer. If any input dependent state + is not set in this method, then ``compute_features`` (equivalent to ``fit_transform``) will break. + + Separating states related to the input (settable with this method) and states that are unrelated + from the input (settable with ``set_kernel`` for Conv bases) is a deliberate design choice + that improves modularity. + + """ + if isinstance(xi, tuple): + if not all(isinstance(i, int) for i in xi): + raise ValueError( + f"The tuple provided contains non integer values. Tuple: {xi}." + ) + shape = xi + elif isinstance(xi, int): + shape = () if xi == 1 else (xi,) + else: + shape = xi.shape[1:] + + n_inputs = (int(np.prod(shape)),) + + self._input_shape_ = shape + + self._n_basis_input_ = n_inputs + return self + @abc.abstractmethod def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: """ @@ -381,13 +434,6 @@ def _check_transform_input( return xi - def _check_has_kernel(self) -> None: - """Check that the kernel is pre-computed.""" - if self.mode == "conv" and self.kernel_ is None: - raise ValueError( - "You must call `_set_kernel` before `_compute_features` when mode =`conv`." - ) - def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: """Evaluate the basis set on a grid of equi-spaced sample points. @@ -560,7 +606,7 @@ def _get_feature_slicing( Parameters ---------- n_inputs : - The number of input basis for each component, by default it uses ``self._n_basis_input``. + The number of input basis for each component, by default it uses ``self._n_basis_input_``. start_slice : The starting index for slicing, by default it starts from 0. split_by_input : @@ -582,9 +628,8 @@ def _get_feature_slicing( _merge_slicing_dicts : Merges multiple slicing dictionaries, handling keys conflicts. """ # Set default values for n_inputs and start_slice if not provided - n_inputs = n_inputs or self._n_basis_input + n_inputs = n_inputs or self._n_basis_input_ start_slice = start_slice or 0 - # Handle the default case for non-additive basis types # See overwritten method for recursion logic split_dict, start_slice = self._get_default_slicing( @@ -609,11 +654,9 @@ def _get_default_slicing( """Handle default slicing logic.""" if split_by_input: # should we remove this option? - if self._n_basis_input[0] == 1 or isinstance(self, MultiplicativeBasis): + if self._n_basis_input_[0] == 1 or isinstance(self, MultiplicativeBasis): split_dict = { - self.label: slice( - start_slice, start_slice + self._n_output_features - ) + self.label: slice(start_slice, start_slice + self.n_output_features) } else: split_dict = { @@ -622,14 +665,14 @@ def _get_default_slicing( start_slice + i * self.n_basis_funcs, start_slice + (i + 1) * self.n_basis_funcs, ) - for i in range(self._n_basis_input[0]) + for i in range(self._n_basis_input_[0]) } } else: split_dict = { - self.label: slice(start_slice, start_slice + self._n_output_features) + self.label: slice(start_slice, start_slice + self.n_output_features) } - start_slice += self._n_output_features + start_slice += self.n_output_features return split_dict, start_slice def split_by_feature( @@ -721,13 +764,13 @@ def is_leaf(val): # Apply the slicing using the custom leaf function out = jax.tree_util.tree_map(lambda sl: x[sl], index_dict, is_leaf=is_leaf) - # reshape the arrays to spilt by n_basis_input + # reshape the arrays to spilt by n_basis_input_ reshaped_out = dict() for i, vals in enumerate(out.items()): key, val = vals shape = list(val.shape) reshaped_out[key] = val.reshape( - shape[:axis] + [self._n_basis_input[i], -1] + shape[axis + 1 :] + shape[:axis] + [self._n_basis_input_[i], -1] + shape[axis + 1 :] ) return reshaped_out @@ -736,10 +779,10 @@ def _check_input_shape_consistency(self, x: NDArray): # remove sample axis and squeeze shape = x.shape[1:] - initialized = self._input_shape is not None - is_shape_match = self._input_shape == shape + initialized = self._input_shape_ is not None + is_shape_match = self._input_shape_ == shape if initialized and not is_shape_match: - expected_shape_str = "(n_samples, " + f"{self._input_shape}"[1:] + expected_shape_str = "(n_samples, " + f"{self._input_shape_}"[1:] expected_shape_str = expected_shape_str.replace(",)", ")") raise ValueError( f"Input shape mismatch detected.\n\n" @@ -752,52 +795,50 @@ def _check_input_shape_consistency(self, x: NDArray): "different shape, please create a new basis instance." ) - def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): - """ - Set the expected input shape for the basis object. + def _list_components(self): + """List all basis components. - This method configures the shape of the input data that the basis object expects. - ``xi`` can be specified as an integer, a tuple of integers, or derived - from an array. The method also calculates the total number of input - features and output features based on the number of basis functions. + This is re-implemented for composite basis in the mixin class. - Parameters - ---------- - xi : - The input shape specification. - - An integer: Represents the dimensionality of the input. A value of ``1`` is treated as scalar input. - - A tuple: Represents the exact input shape excluding the first axis (sample axis). - All elements must be integers. - - An array: The shape is extracted, excluding the first axis (assumed to be the sample axis). + Returns + ------- + A list with all 1d basis components. Raises ------ - ValueError - If a tuple is provided and it contains non-integer elements. - - Returns - ------- - self : - Returns the instance itself to allow method chaining. + RuntimeError + If the basis has multiple components. This would only happen if there is an + implementation issue, for example, if a composite basis is implemented but the + mixin class is not initialized, or if the _list_components method of the composite mixin + class is accidentally removed. """ - if isinstance(xi, tuple): - if not all(isinstance(i, int) for i in xi): - raise ValueError( - f"The tuple provided contains non integer values. Tuple: {xi}." - ) - shape = xi - elif isinstance(xi, int): - shape = () if xi == 1 else (xi,) - else: - shape = xi.shape[1:] + if hasattr(self, "basis1"): + raise RuntimeError( + "Composite basis must implement the _list_components method." + ) + return [self] - n_inputs = (int(np.prod(shape)),) + def __sklearn_clone__(self) -> Basis: + """Clone the basis while preserving attributes related to input shapes. - self._input_shape = shape + This method ensures that input shape attributes (e.g., `_n_basis_input_`, + `_input_shape_`) are preserved during cloning. Reinitializing the class + as in the regular sklearn clone would drop these attributes, rendering + cross-validation unusable. + The method also handles recursive cloning for composite basis structures. + """ + # clone recursively + if hasattr(self, "_basis1") and hasattr(self, "_basis2"): + basis1 = self._basis1.__sklearn_clone__() + basis2 = self._basis2.__sklearn_clone__() + klass = self.__class__(basis1, basis2) - self._n_basis_input = n_inputs - self._n_output_features = self.n_basis_funcs * self._n_basis_input[0] - return self + else: + klass = self.__class__(**self.get_params()) + + for attr_name in ["_n_basis_input_", "_input_shape_"]: + setattr(klass, attr_name, getattr(self, attr_name)) + return klass class AdditiveBasis(CompositeBasisMixin, Basis): @@ -811,11 +852,6 @@ class AdditiveBasis(CompositeBasisMixin, Basis): basis2 : Second basis object to add. - Attributes - ---------- - n_basis_funcs : int - Number of basis functions. - Examples -------- >>> # Generate sample data @@ -835,17 +871,31 @@ class AdditiveBasis(CompositeBasisMixin, Basis): """ def __init__(self, basis1: Basis, basis2: Basis) -> None: - self.n_basis_funcs = basis1.n_basis_funcs + basis2.n_basis_funcs - super().__init__(self.n_basis_funcs, mode="eval") + CompositeBasisMixin.__init__(self, basis1, basis2) + Basis.__init__(self, mode="eval") + self._label = "(" + basis1.label + " + " + basis2.label + ")" + self._n_input_dimensionality = ( basis1._n_input_dimensionality + basis2._n_input_dimensionality ) - self._n_basis_input = None - self._n_output_features = None - self._label = "(" + basis1.label + " + " + basis2.label + ")" - self._basis1 = basis1 - self._basis2 = basis2 - CompositeBasisMixin.__init__(self) + + @property + def n_basis_funcs(self): + """Compute the n-basis function runtime. + + This plays well with cross-validation where the number of basis function of the + underlying bases can be changed. It must be read-only since the number of basis + is determined by the two basis elements and the type of composition. + """ + return self.basis1.n_basis_funcs + self.basis2.n_basis_funcs + + @property + def n_output_features(self): + out1 = getattr(self._basis1, "n_output_features", None) + out2 = getattr(self._basis2, "n_output_features", None) + if out1 is None or out2 is None: + return None + return out1 + out2 def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: """ @@ -896,16 +946,13 @@ def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: 181 """ - self._n_basis_input = ( + self._n_basis_input_ = ( *self._basis1.set_input_shape( *xi[: self._basis1._n_input_dimensionality] - )._n_basis_input, + )._n_basis_input_, *self._basis2.set_input_shape( *xi[self._basis1._n_input_dimensionality :] - )._n_basis_input, - ) - self._n_output_features = ( - self._basis1.n_output_features + self._basis2.n_output_features + )._n_basis_input_, ) return self @@ -1209,18 +1256,18 @@ def _get_feature_slicing( _merge_slicing_dicts : Merges multiple slicing dictionaries, handling keys conflicts. """ # Set default values for n_inputs and start_slice if not provided - n_inputs = n_inputs or self._n_basis_input + n_inputs = n_inputs or self._n_basis_input_ start_slice = start_slice or 0 # If the instance is of AdditiveBasis type, handle slicing for the additive components split_dict, start_slice = self._basis1._get_feature_slicing( - n_inputs[: len(self._basis1._n_basis_input)], + n_inputs[: len(self._basis1._n_basis_input_)], start_slice, split_by_input=split_by_input, ) sp2, start_slice = self._basis2._get_feature_slicing( - n_inputs[len(self._basis1._n_basis_input) :], + n_inputs[len(self._basis1._n_basis_input_) :], start_slice, split_by_input=split_by_input, ) @@ -1249,11 +1296,6 @@ class MultiplicativeBasis(CompositeBasisMixin, Basis): basis2 : Second basis object to multiply. - Attributes - ---------- - n_basis_funcs : - Number of basis functions. - Examples -------- >>> # Generate sample data @@ -1273,39 +1315,30 @@ class MultiplicativeBasis(CompositeBasisMixin, Basis): """ def __init__(self, basis1: Basis, basis2: Basis) -> None: - self.n_basis_funcs = basis1.n_basis_funcs * basis2.n_basis_funcs - CompositeBasisMixin.__init__(self) - super().__init__(self.n_basis_funcs, mode="eval") + CompositeBasisMixin.__init__(self, basis1, basis2) + Basis.__init__(self, mode="eval") + self._label = "(" + basis1.label + " * " + basis2.label + ")" self._n_input_dimensionality = ( basis1._n_input_dimensionality + basis2._n_input_dimensionality ) - self._n_basis_input = None - self._n_output_features = None - self._label = "(" + basis1.label + " * " + basis2.label + ")" - self._basis1 = basis1 - self._basis2 = basis2 - - def _check_n_basis_min(self) -> None: - pass - - def set_kernel(self, *xi: NDArray) -> Basis: - """Call fit on the multiplied basis. - - If any of the added basis is in "conv" mode, it will prepare its kernels for the convolution. - Parameters - ---------- - *xi: - The sample inputs. Unused, necessary to conform to ``scikit-learn`` API. + @property + def n_basis_funcs(self): + """Compute the n-basis function runtime. - Returns - ------- - : - The MultiplicativeBasis ready to be evaluated. + This plays well with cross-validation where the number of basis function of the + underlying bases can be changed. It must be read-only since the number of basis + is determined by the two basis elements and the type of composition. """ - self._basis1.set_kernel() - self._basis2.set_kernel() - return self + return self.basis1.n_basis_funcs * self.basis2.n_basis_funcs + + @property + def n_output_features(self): + out1 = getattr(self._basis1, "n_output_features", None) + out2 = getattr(self._basis2, "n_output_features", None) + if out1 is None or out2 is None: + return None + return out1 * out2 @support_pynapple(conv_type="numpy") @check_transform_input @@ -1423,16 +1456,13 @@ def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: 25200 """ - self._n_basis_input = ( + self._n_basis_input_ = ( *self._basis1.set_input_shape( *xi[: self._basis1._n_input_dimensionality] - )._n_basis_input, + )._n_basis_input_, *self._basis2.set_input_shape( *xi[self._basis1._n_input_dimensionality :] - )._n_basis_input, - ) - self._n_output_features = ( - self._basis1.n_output_features * self._basis2.n_output_features + )._n_basis_input_, ) return self diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 098aeb34..9b208e31 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -2,8 +2,10 @@ from __future__ import annotations +import abc import copy import inspect +import warnings from typing import TYPE_CHECKING, Optional, Tuple, Union import numpy as np @@ -20,8 +22,11 @@ class EvalBasisMixin: """Mixin class for evaluational basis.""" - def __init__(self, bounds: Optional[Tuple[float, float]] = None): + def __init__( + self, n_basis_funcs: int, bounds: Optional[Tuple[float, float]] = None + ): self.bounds = bounds + self._n_basis_funcs = n_basis_funcs def _compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor): """Evaluate basis at sample points. @@ -51,9 +56,32 @@ def _compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor): out = self._evaluate(*(np.reshape(x, (x.shape[0], -1)) for x in xi)) return np.reshape(out, (out.shape[0], -1)) - def set_kernel(self) -> "EvalBasisMixin": + def setup_basis(self, *xi: NDArray) -> Basis: """ - Prepare or compute the convolutional kernel for the basis functions. + Set all basis states. + + This method corresponds sklearn transformer ``fit``. As fit, it must receive the input and + it must set all basis states, i.e. kernel_ and all the states relative to the input shape. + The difference between this method and the transformer ``fit`` is in the expected input structure, + where the transformer ``fit`` method requires the inputs to be concatenated in a 2D array, while here + each input is provided as a separate time series for each basis element. + + Parameters + ---------- + xi: + Input arrays. + + Returns + ------- + : + The basis with ready for evaluation. + """ + self.set_input_shape(*xi) + return self + + def _set_input_independent_states(self) -> "EvalBasisMixin": + """ + Compute all the basis states that do not depend on the input. For EvalBasisMixin, this method might not perform any operation but simply return the instance itself, as no kernel preparation is necessary. @@ -94,9 +122,13 @@ def bounds(self, values: Union[None, Tuple[float, float]]): class ConvBasisMixin: """Mixin class for convolutional basis.""" - def __init__(self, window_size: int, conv_kwargs: Optional[dict] = None): + def __init__( + self, n_basis_funcs: int, window_size: int, conv_kwargs: Optional[dict] = None + ): + self.kernel_ = None self.window_size = window_size self.conv_kwargs = {} if conv_kwargs is None else conv_kwargs + self._n_basis_funcs = n_basis_funcs def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor): """Convolve basis functions with input time series. @@ -114,10 +146,18 @@ def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor): The input data over which to apply the basis transformation. The samples can be passed as multiple arguments, each representing a different dimension for multivariate inputs. + Notes + ----- + This method is intended to be 1-to-1 mappable to sklearn ``transform`` method of transformer. This + means that for the method to be callable, all the state attributes have to be pre-computed in a + method that is mappable to ``fit``, which for us is ``_fit_basis``. It is fundamental that both + methods behaves like the corresponding transformer method, with the only difference being the input + structure: a single (X, y) pair for the transformer, a number of time series for the Basis. + """ if self.kernel_ is None: raise ValueError( - "You must call `_set_kernel` before `_compute_features`! " + "You must call `setup_basis` before `_compute_features`! " "Convolution kernel is not set." ) # before calling the convolve, check that the input matches @@ -127,6 +167,38 @@ def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor): # make sure to return a matrix return np.reshape(conv, newshape=(conv.shape[0], -1)) + def setup_basis(self, *xi: NDArray) -> Basis: + """ + Set all basis states. + + This method corresponds sklearn transformer ``fit``. As fit, it must receive the input and + it must set all basis states, i.e. kernel_ and all the states relative to the input shape. + The difference between this method and the transformer ``fit`` is in the expected input structure, + where the transformer ``fit`` method requires the inputs to be concatenated in a 2D array, while here + each input is provided as a separate time series for each basis element. + + Parameters + ---------- + xi: + Input arrays. + + Returns + ------- + : + The basis with ready for evaluation. + """ + self.set_kernel() + self.set_input_shape(*xi) + return self + + def _set_input_independent_states(self): + """ + Compute all the basis states that do not depend on the input. + + For Conv mixin the only attribute is the kernel. + """ + return self.set_kernel() + def set_kernel(self) -> "ConvBasisMixin": """ Prepare or compute the convolutional kernel for the basis functions. @@ -160,6 +232,11 @@ def window_size(self): @window_size.setter def window_size(self, window_size): """Setter for the window size parameter.""" + self._check_window_size(window_size) + + self._window_size = window_size + + def _check_window_size(self, window_size): if window_size is None: raise ValueError("You must provide a window_size!") @@ -168,8 +245,6 @@ def window_size(self, window_size): f"`window_size` must be a positive integer. {window_size} provided instead!" ) - self._window_size = window_size - @property def conv_kwargs(self): """The convolutional kwargs. @@ -227,6 +302,13 @@ def _check_convolution_kwargs(conv_kwargs: dict): f"Allowed convolution keyword arguments are: {convolve_configs}." ) + def _check_has_kernel(self) -> None: + """Check that the kernel is pre-computed.""" + if self.kernel_ is None: + raise ValueError( + "You must call `_set_kernel` before `_compute_features` for Conv basis." + ) + class BasisTransformerMixin: """Mixin class for constructing a transformer.""" @@ -244,7 +326,7 @@ def to_transformer(self) -> TransformerBasis: >>> from sklearn.model_selection import GridSearchCV >>> # load some data >>> X, y = np.random.normal(size=(30, 1)), np.random.poisson(size=30) - >>> basis = nmo.basis.RaisedCosineLinearEval(10).to_transformer() + >>> basis = nmo.basis.RaisedCosineLinearEval(10).set_input_shape(1).to_transformer() >>> glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=1.) >>> pipeline = Pipeline([("basis", basis), ("glm", glm)]) >>> param_grid = dict( @@ -258,7 +340,7 @@ def to_transformer(self) -> TransformerBasis: ... ) >>> gridsearch = gridsearch.fit(X, y) """ - return TransformerBasis(copy.deepcopy(self)) + return TransformerBasis(self) class CompositeBasisMixin: @@ -268,28 +350,82 @@ class CompositeBasisMixin: (AdditiveBasis and MultiplicativeBasis). """ + def __init__(self, basis1: Basis, basis2: Basis): + # deep copy to avoid changes directly to the 1d basis to be reflected + # in the composite basis. + self.basis1 = copy.deepcopy(basis1) + self.basis2 = copy.deepcopy(basis2) + + # set parents + self.basis1._parent = self + self.basis2._parent = self + + shapes = ( + *(bas1._input_shape_ for bas1 in basis1._list_components()), + *(bas2._input_shape_ for bas2 in basis2._list_components()), + ) + # if all bases where set, then set input for composition. + set_bases = (s is not None for s in shapes) + + if all(set_bases): + # pass down the input shapes + self.set_input_shape(*shapes) + elif any(set_bases): + warnings.warn( + "Only some of the basis where initialized with `set_input_shape`, " + "please initialize the composite basis before computing features.", + category=UserWarning, + ) + + @property + @abc.abstractmethod + def n_basis_funcs(self): + """Read only property for composite bases.""" + pass + def _check_n_basis_min(self) -> None: pass - def set_kernel(self, *xi: NDArray) -> Basis: - """Call set_kernel on the basis elements. + def setup_basis(self, *xi: NDArray) -> Basis: + """ + Set all basis states. - If any of the basis elements is in "conv" mode, it will prepare its kernels for the convolution. + This method corresponds sklearn transformer ``fit``. As fit, it must receive the input and + it must set all basis states, i.e. kernel_ and all the states relative to the input shape. + The difference between this method and the transformer ``fit`` is in the expected input structure, + where the transformer ``fit`` method requires the inputs to be concatenated in a 2D array, while here + each input is provided as a separate time series for each basis element. Parameters ---------- - *xi: - The sample inputs. Unused, necessary to conform to ``scikit-learn`` API. + xi: + Input arrays. Returns ------- : - The basis ready to be evaluated. + The basis with ready for evaluation. """ - self._basis1.set_kernel() - self._basis2.set_kernel() + # setup both input independent + self._set_input_independent_states() + + # and input dependent states + self.set_input_shape(*xi) + return self + def _set_input_independent_states(self): + """ + Compute the input dependent states for traversing the composite basis. + + Returns + ------- + : + The basis with the states stored as attributes of each component. + """ + self.basis1._set_input_independent_states() + self.basis2._set_input_independent_states() + def _check_input_shape_consistency(self, *xi: NDArray): """Check the input shape consistency for all basis elements.""" self._basis1._check_input_shape_consistency( @@ -298,3 +434,31 @@ def _check_input_shape_consistency(self, *xi: NDArray): self._basis2._check_input_shape_consistency( *xi[self._basis1._n_input_dimensionality :] ) + + @property + def basis1(self): + return self._basis1 + + @basis1.setter + def basis1(self, bas: Basis): + self._basis1 = bas + + @property + def basis2(self): + return self._basis2 + + @basis2.setter + def basis2(self, bas: Basis): + self._basis2 = bas + + def _list_components(self): + """List all basis components. + + Reimplements the default behavior by iteratively calling _list_components of the + elements. + + Returns + ------- + A list with all 1d basis components. + """ + return self._basis1._list_components() + self._basis2._list_components() diff --git a/src/nemos/basis/_decaying_exponential.py b/src/nemos/basis/_decaying_exponential.py index 65a71a3e..7df05947 100644 --- a/src/nemos/basis/_decaying_exponential.py +++ b/src/nemos/basis/_decaying_exponential.py @@ -21,8 +21,6 @@ class OrthExponentialBasis(Basis, abc.ABC): Parameters ---------- - n_basis_funcs - Number of basis functions. decay_rates : Decay rates of the exponentials, shape ``(n_basis_funcs,)``. mode : @@ -35,13 +33,11 @@ class OrthExponentialBasis(Basis, abc.ABC): def __init__( self, - n_basis_funcs: int, decay_rates: NDArray[np.floating], mode="eval", label: Optional[str] = "OrthExponentialBasis", ): super().__init__( - n_basis_funcs, mode=mode, label=label, ) diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index 07c3ae0a..0521a683 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -21,8 +21,6 @@ class RaisedCosineBasisLinear(Basis, abc.ABC): Parameters ---------- - n_basis_funcs : - The number of basis functions. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. @@ -42,13 +40,11 @@ class RaisedCosineBasisLinear(Basis, abc.ABC): def __init__( self, - n_basis_funcs: int, mode="eval", width: float = 2.0, label: Optional[str] = "RaisedCosineBasisLinear", ) -> None: super().__init__( - n_basis_funcs, mode=mode, label=label, ) @@ -234,7 +230,6 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear, abc.ABC): def __init__( self, - n_basis_funcs: int, mode="eval", width: float = 2.0, time_scaling: float = None, @@ -242,7 +237,6 @@ def __init__( label: Optional[str] = "RaisedCosineBasisLog", ) -> None: super().__init__( - n_basis_funcs, mode=mode, width=width, label=label, diff --git a/src/nemos/basis/_spline_basis.py b/src/nemos/basis/_spline_basis.py index d9969029..5fc4c38e 100644 --- a/src/nemos/basis/_spline_basis.py +++ b/src/nemos/basis/_spline_basis.py @@ -21,8 +21,6 @@ class SplineBasis(Basis, abc.ABC): Parameters ---------- - n_basis_funcs : - Number of basis functions. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. @@ -40,14 +38,12 @@ class SplineBasis(Basis, abc.ABC): def __init__( self, - n_basis_funcs: int, order: int = 2, label: Optional[str] = None, mode: Literal["conv", "eval"] = "eval", ) -> None: self.order = order super().__init__( - n_basis_funcs, label=label, mode=mode, ) @@ -158,9 +154,6 @@ class MSplineBasis(SplineBasis, abc.ABC): Parameters ---------- - n_basis_funcs : - The number of basis functions to generate. More basis functions allow for - more flexible data modeling but can lead to overfitting. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. @@ -198,13 +191,11 @@ class MSplineBasis(SplineBasis, abc.ABC): def __init__( self, - n_basis_funcs: int, mode: Literal["eval", "conv"] = "eval", order: int = 2, label: Optional[str] = "MSplineEval", ) -> None: super().__init__( - n_basis_funcs, mode=mode, order=order, label=label, @@ -301,8 +292,6 @@ class BSplineBasis(SplineBasis, abc.ABC): Parameters ---------- - n_basis_funcs : - Number of basis functions. mode : The mode of operation. ``'eval'`` for evaluation at sample points, 'conv' for convolutional operation. @@ -328,13 +317,11 @@ class BSplineBasis(SplineBasis, abc.ABC): def __init__( self, - n_basis_funcs: int, mode="eval", order: int = 4, label: Optional[str] = "BSplineBasis", ): super().__init__( - n_basis_funcs, mode=mode, order=order, label=label, @@ -419,8 +406,6 @@ class CyclicBSplineBasis(SplineBasis, abc.ABC): Parameters ---------- - n_basis_funcs : - Number of basis functions. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. @@ -442,13 +427,11 @@ class CyclicBSplineBasis(SplineBasis, abc.ABC): def __init__( self, - n_basis_funcs: int, mode="eval", order: int = 4, label: Optional[str] = "CyclicBSplineBasis", ): super().__init__( - n_basis_funcs, mode=mode, order=order, label=label, diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 83f4f2e3..91865028 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -1,7 +1,9 @@ from __future__ import annotations import copy -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List + +import numpy as np from ..typing import FeatureMatrix @@ -63,12 +65,28 @@ def __init__(self, basis: Basis): self._basis = copy.deepcopy(basis) @staticmethod - def _unpack_inputs(X: FeatureMatrix): - """Unpack inputs without using transpose. + def _check_initialized(basis): + if basis._n_basis_input_ is None: + raise RuntimeError( + "Cannot apply TransformerBasis: the provided basis has no defined input shape. " + "Please call `set_input_shape` before calling `fit`, `transform`, or " + "`fit_transform`." + ) + + @property + def basis(self): + return self._basis + + @basis.setter + def basis(self, basis): + self._check_initialized(basis) + self._basis = basis + + def _unpack_inputs(self, X: FeatureMatrix) -> List: + """Unpack inputs. Unpack horizontally stacked inputs using slicing. This works gracefully with ``pynapple``, - returning a list of Tsd objects. Attempt to unpack using *X.T will raise a ``pynapple`` - exception since ``pynapple`` assumes that the time axis is the first axis. + returning a list of Tsd objects. Parameters ---------- @@ -78,10 +96,19 @@ def _unpack_inputs(X: FeatureMatrix): Returns ------- : - A tuple of each individual input. + A list of each individual input. """ - return (X[:, k] for k in range(X.shape[1])) + n_samples = X.shape[0] + out = [] + cc = 0 + for i, bas in enumerate(self._list_components()): + n_input = self._n_basis_input_[i] + out.append( + np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape_)) + ) + cc += n_input + return out def fit(self, X: FeatureMatrix, y=None): """ @@ -110,11 +137,11 @@ def fit(self, X: FeatureMatrix, y=None): >>> X = np.random.normal(size=(100, 2)) >>> # Define and fit tranformation basis - >>> basis = MSplineEval(10) + >>> basis = MSplineEval(10).set_input_shape(2) >>> transformer = TransformerBasis(basis) >>> transformer_fitted = transformer.fit(X) """ - self._basis.set_kernel() + self._basis.setup_basis(*self._unpack_inputs(X)) return self def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: @@ -141,7 +168,7 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: >>> # Example input >>> X = np.random.normal(size=(10000, 2)) - >>> basis = MSplineConv(10, window_size=200) + >>> basis = MSplineConv(10, window_size=200).set_input_shape(2) >>> transformer = TransformerBasis(basis) >>> # Before calling `fit` the convolution kernel is not set >>> transformer.kernel_ @@ -152,7 +179,7 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: (200, 10) >>> # Transform basis - >>> feature_transformed = transformer.transform(X[:, 0:1]) + >>> feature_transformed = transformer.transform(X) """ # transpose does not work with pynapple # can't use func(*X.T) to unwrap @@ -187,7 +214,7 @@ def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: >>> X = np.random.normal(size=(100, 1)) >>> # Define tranformation basis - >>> basis = MSplineEval(10) + >>> basis = MSplineEval(10).set_input_shape(1) >>> transformer = TransformerBasis(basis) >>> # Fit and transform basis diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 0fee5651..702f7b56 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -83,10 +83,9 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "BSplineEval", ): - EvalBasisMixin.__init__(self, bounds=bounds) + EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) BSplineBasis.__init__( self, - n_basis_funcs, mode="eval", order=order, label=label, @@ -237,10 +236,14 @@ def __init__( label: Optional[str] = "BSplineConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) + ConvBasisMixin.__init__( + self, + n_basis_funcs=n_basis_funcs, + window_size=window_size, + conv_kwargs=conv_kwargs, + ) BSplineBasis.__init__( self, - n_basis_funcs, mode="conv", order=order, label=label, @@ -378,10 +381,9 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "CyclicBSplineEval", ): - EvalBasisMixin.__init__(self, bounds=bounds) + EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) CyclicBSplineBasis.__init__( self, - n_basis_funcs, mode="eval", order=order, label=label, @@ -524,10 +526,14 @@ def __init__( label: Optional[str] = "CyclicBSplineConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) + ConvBasisMixin.__init__( + self, + n_basis_funcs=n_basis_funcs, + window_size=window_size, + conv_kwargs=conv_kwargs, + ) CyclicBSplineBasis.__init__( self, - n_basis_funcs, mode="conv", order=order, label=label, @@ -689,10 +695,9 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "MSplineEval", ): - EvalBasisMixin.__init__(self, bounds=bounds) + EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) MSplineBasis.__init__( self, - n_basis_funcs, mode="eval", order=order, label=label, @@ -859,10 +864,14 @@ def __init__( label: Optional[str] = "MSplineConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) + ConvBasisMixin.__init__( + self, + n_basis_funcs=n_basis_funcs, + window_size=window_size, + conv_kwargs=conv_kwargs, + ) MSplineBasis.__init__( self, - n_basis_funcs, mode="conv", order=order, label=label, @@ -1008,10 +1017,9 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "RaisedCosineLinearEval", ): - EvalBasisMixin.__init__(self, bounds=bounds) + EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) RaisedCosineBasisLinear.__init__( self, - n_basis_funcs, width=width, mode="eval", label=label, @@ -1155,10 +1163,14 @@ def __init__( label: Optional[str] = "RaisedCosineLinearConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) + ConvBasisMixin.__init__( + self, + n_basis_funcs=n_basis_funcs, + window_size=window_size, + conv_kwargs=conv_kwargs, + ) RaisedCosineBasisLinear.__init__( self, - n_basis_funcs, mode="conv", width=width, label=label, @@ -1311,10 +1323,9 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "RaisedCosineLogEval", ): - EvalBasisMixin.__init__(self, bounds=bounds) + EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) RaisedCosineBasisLog.__init__( self, - n_basis_funcs, width=width, time_scaling=time_scaling, enforce_decay_to_zero=enforce_decay_to_zero, @@ -1470,10 +1481,14 @@ def __init__( label: Optional[str] = "RaisedCosineLogConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) + ConvBasisMixin.__init__( + self, + n_basis_funcs=n_basis_funcs, + window_size=window_size, + conv_kwargs=conv_kwargs, + ) RaisedCosineBasisLog.__init__( self, - n_basis_funcs, mode="conv", width=width, time_scaling=time_scaling, @@ -1608,10 +1623,9 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "OrthExponentialEval", ): - EvalBasisMixin.__init__(self, bounds=bounds) + EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds) OrthExponentialBasis.__init__( self, - n_basis_funcs, decay_rates=decay_rates, mode="eval", label=label, @@ -1751,14 +1765,21 @@ def __init__( label: Optional[str] = "OrthExponentialConv", conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) + ConvBasisMixin.__init__( + self, + n_basis_funcs=n_basis_funcs, + window_size=window_size, + conv_kwargs=conv_kwargs, + ) OrthExponentialBasis.__init__( self, - n_basis_funcs, mode="conv", decay_rates=decay_rates, label=label, ) + # re-check window size because n_basis_funcs is not set yet when the + # property setter runs the first check. + self._check_window_size(self.window_size) @add_docstring("evaluate_on_grid", OrthExponentialBasis) def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: @@ -1850,3 +1871,31 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ return super().set_input_shape(xi) + + def _check_window_size(self, window_size: int): + """OrthExponentialBasis specific window size check.""" + super()._check_window_size(window_size) + # if n_basis_funcs is not yet initialized, skip check + n_basis = getattr(self, "n_basis_funcs", None) + if n_basis and window_size < n_basis: + raise ValueError( + "OrthExponentialConv basis requires at least a window_size larger then the number " + f"of basis functions. window_size is {window_size}, n_basis_funcs while" + f"is {self.n_basis_funcs}." + ) + + def set_kernel(self): + try: + super().set_kernel() + except ValueError as e: + if "OrthExponentialBasis requires at least as many" in str(e): + raise ValueError( + "Cannot set the kernels for OrthExponentialBasis when `window_size` is smaller " + "than `n_basis_funcs.\n" + "Please, increase the window size or reduce the number of basis functions. " + f"Current `window_size` is {self.window_size}, while `n_basis_funcs` is " + f"{self.n_basis_funcs}." + ) + else: + raise e + return self diff --git a/tests/conftest.py b/tests/conftest.py index eb88ed10..3daba960 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,8 @@ and loading predefined parameters for testing various functionalities of the NeMoS library. """ +import abc + import jax import jax.numpy as jnp import numpy as np @@ -16,11 +18,112 @@ import pytest import nemos as nmo +import nemos._inspect_utils as inspect_utils +import nemos.basis.basis as basis +from nemos.basis import AdditiveBasis, MultiplicativeBasis +from nemos.basis._basis import Basis # shut-off conversion warnings nap.nap_config.suppress_conversion_warnings = True +@pytest.fixture() +def basis_class_specific_params(): + """Returns all the params for each class.""" + all_cls = list_all_basis_classes("Conv") + list_all_basis_classes("Eval") + return {cls.__name__: cls._get_param_names() for cls in all_cls} + + +class BasisFuncsTesting(abc.ABC): + """ + An abstract base class that sets the foundation for individual basis function testing. + This class requires an implementation of a 'cls' method, which is utilized by the meta-test + that verifies if all basis functions are properly tested. + """ + + @abc.abstractmethod + def cls(self): + pass + + +class CombinedBasis(BasisFuncsTesting): + """ + This class is used to run tests on combination operations (e.g., addition, multiplication) among Basis functions. + + Properties: + - cls: Class (default = None) + """ + + cls = None + + @staticmethod + def instantiate_basis( + n_basis, basis_class, class_specific_params, window_size=10, **kwargs + ): + """Instantiate and return two basis of the type specified.""" + + # Set non-optional args + default_kwargs = { + "n_basis_funcs": n_basis, + "window_size": window_size, + "decay_rates": np.arange(1, 1 + n_basis), + } + repeated_keys = set(default_kwargs.keys()).intersection(kwargs.keys()) + if repeated_keys: + raise ValueError( + "Cannot set `n_basis_funcs, window_size, decay_rates` with kwargs" + ) + + # Merge with provided extra kwargs + kwargs = {**default_kwargs, **kwargs} + + if basis_class == AdditiveBasis: + kwargs_mspline = inspect_utils.trim_kwargs( + basis.MSplineEval, kwargs, class_specific_params + ) + kwargs_raised_cosine = inspect_utils.trim_kwargs( + basis.RaisedCosineLinearConv, kwargs, class_specific_params + ) + b1 = basis.MSplineEval(**kwargs_mspline) + b2 = basis.RaisedCosineLinearConv(**kwargs_raised_cosine) + basis_obj = b1 + b2 + elif basis_class == MultiplicativeBasis: + kwargs_mspline = inspect_utils.trim_kwargs( + basis.MSplineEval, kwargs, class_specific_params + ) + kwargs_raised_cosine = inspect_utils.trim_kwargs( + basis.RaisedCosineLinearConv, kwargs, class_specific_params + ) + b1 = basis.MSplineEval(**kwargs_mspline) + b2 = basis.RaisedCosineLinearConv(**kwargs_raised_cosine) + basis_obj = b1 * b2 + else: + basis_obj = basis_class( + **inspect_utils.trim_kwargs(basis_class, kwargs, class_specific_params) + ) + return basis_obj + + +# automatic define user accessible basis and check the methods +def list_all_basis_classes(filter_basis="all") -> list[type]: + """ + Return all the classes in nemos.basis which are a subclass of Basis, + which should be all concrete classes except TransformerBasis. + """ + all_basis = [ + class_obj + for _, class_obj in inspect_utils.get_non_abstract_classes(basis) + if issubclass(class_obj, Basis) + ] + [ + bas + for _, bas in inspect_utils.get_non_abstract_classes(nmo.basis._basis) + if bas != basis.TransformerBasis + ] + if filter_basis != "all": + all_basis = [a for a in all_basis if filter_basis in a.__name__] + return all_basis + + # Sample subclass to test instantiation and methods class MockRegressor(nmo.base_regressor.BaseRegressor): """ diff --git a/tests/test_basis.py b/tests/test_basis.py index d6168bc8..86c58fa2 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -1,4 +1,3 @@ -import abc import inspect import itertools import pickle @@ -11,9 +10,8 @@ import numpy as np import pynapple as nap import pytest -from sklearn.base import clone as sk_clone +from conftest import BasisFuncsTesting, CombinedBasis, list_all_basis_classes -import nemos as nmo import nemos._inspect_utils as inspect_utils import nemos.basis.basis as basis import nemos.convolve as convolve @@ -34,33 +32,6 @@ def extra_decay_rates(cls, n_basis): return {} -# automatic define user accessible basis and check the methods -def list_all_basis_classes(filter_basis="all") -> list[type]: - """ - Return all the classes in nemos.basis which are a subclass of Basis, - which should be all concrete classes except TransformerBasis. - """ - all_basis = [ - class_obj - for _, class_obj in inspect_utils.get_non_abstract_classes(basis) - if issubclass(class_obj, Basis) - ] + [ - bas - for _, bas in inspect_utils.get_non_abstract_classes(nmo.basis._basis) - if bas != basis.TransformerBasis - ] - if filter_basis != "all": - all_basis = [a for a in all_basis if filter_basis in a.__name__] - return all_basis - - -@pytest.fixture() -def class_specific_params(): - """Returns all the params for each class.""" - all_cls = list_all_basis_classes("Conv") + list_all_basis_classes("Eval") - return {cls.__name__: cls._get_param_names() for cls in all_cls} - - def test_all_basis_are_tested() -> None: """Meta-test. @@ -137,11 +108,11 @@ def test_all_basis_are_tested() -> None: ], ) def test_example_docstrings_add( - basis_cls, method_name, descr_match, class_specific_params + basis_cls, method_name, descr_match, basis_class_specific_params ): basis_instance = CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 + 5, basis_cls, basis_class_specific_params, window_size=10 ) method = getattr(basis_instance, method_name) doc = method.__doc__ @@ -285,19 +256,6 @@ def test_expected_output_split_by_feature(basis_instance, super_class): np.testing.assert_array_equal(xx[~nans], x[~nans]) -class BasisFuncsTesting(abc.ABC): - """ - An abstract base class that sets the foundation for individual basis function testing. - This class requires an implementation of a 'cls' method, which is utilized by the meta-test - that verifies if all basis functions are properly tested. - """ - - @abc.abstractmethod - def cls(self): - pass - - -# Auto-generated file with stripped classes and shared methods @pytest.mark.parametrize( "cls", [ @@ -343,12 +301,44 @@ def test_call_vmin_vmax(self, samples, vmin, vmax, expectation, cls): with expectation: bas._evaluate(samples) + @pytest.mark.parametrize("n_basis", [5, 6]) + @pytest.mark.parametrize("vmin, vmax", [(0, 1), (-1, 1)]) + @pytest.mark.parametrize("inp_num", [1, 2]) + def test_sklearn_clone_eval(self, cls, n_basis, vmin, vmax, inp_num): + bas = cls["eval"]( + n_basis, bounds=(vmin, vmax), **extra_decay_rates(cls["eval"], n_basis) + ) + bas.set_input_shape(inp_num) + bas2 = bas.__sklearn_clone__() + assert id(bas) != id(bas2) + assert np.all( + bas.__dict__.pop("decay_rates", True) + == bas2.__dict__.pop("decay_rates", True) + ) + assert bas.__dict__ == bas2.__dict__ + + @pytest.mark.parametrize("n_basis", [5, 6]) + @pytest.mark.parametrize("ws", [10, 20]) + @pytest.mark.parametrize("inp_num", [1, 2]) + def test_sklearn_clone_conv(self, cls, n_basis, ws, inp_num): + bas = cls["conv"]( + n_basis, window_size=ws, **extra_decay_rates(cls["eval"], n_basis) + ) + bas.set_input_shape(inp_num) + bas2 = bas.__sklearn_clone__() + assert id(bas) != id(bas2) + assert np.all( + bas.__dict__.pop("decay_rates", True) + == bas2.__dict__.pop("decay_rates", True) + ) + assert bas.__dict__ == bas2.__dict__ + @pytest.mark.parametrize( "attribute, value", [ ("label", None), ("label", "label"), - ("n_basis_input", 1), + ("n_basis_input_", 1), ("n_output_features", 5), ], ) @@ -442,10 +432,10 @@ def test_set_num_basis_input(self, n_input, cls): bas = cls["conv"]( n_basis_funcs=5, window_size=10, **extra_decay_rates(cls["conv"], 5) ) - assert bas.n_basis_input is None + assert bas.n_basis_input_ is None bas.compute_features(np.random.randn(20, n_input)) - assert bas.n_basis_input == (n_input,) - assert bas._n_basis_input == (n_input,) + assert bas.n_basis_input_ == (n_input,) + assert bas._n_basis_input_ == (n_input,) @pytest.mark.parametrize( "bounds, samples, nan_idx, mn, mx", @@ -520,7 +510,7 @@ def test_vmin_vmax_init(self, bounds, expectation, cls): @pytest.mark.parametrize("n_basis", [6, 7]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_basis_number(self, n_basis, mode, kwargs, cls): @@ -552,7 +542,7 @@ def test_call_equivalent_in_conv(self, n_basis, cls): ], ) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) @pytest.mark.parametrize("n_basis", [6]) def test_call_input_num(self, num_input, n_basis, mode, kwargs, expectation, cls): @@ -572,7 +562,7 @@ def test_call_input_num(self, num_input, n_basis, mode, kwargs, expectation, cls ) @pytest.mark.parametrize("n_basis", [6]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_input_shape(self, inp, mode, kwargs, expectation, n_basis, cls): bas = cls[mode]( @@ -586,7 +576,7 @@ def test_call_input_shape(self, inp, mode, kwargs, expectation, n_basis, cls): @pytest.mark.parametrize("n_basis", [6]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_nan_location(self, mode, kwargs, n_basis, cls): bas = cls[mode]( @@ -619,7 +609,7 @@ def test_call_input_type(self, samples, expectation, n_basis, cls): bas._evaluate(samples) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_nan(self, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) @@ -629,7 +619,7 @@ def test_call_nan(self, mode, kwargs, cls): @pytest.mark.parametrize("n_basis", [6, 7]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_non_empty(self, n_basis, mode, kwargs, cls): bas = cls[mode]( @@ -640,7 +630,7 @@ def test_call_non_empty(self, n_basis, mode, kwargs, cls): @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_sample_axis(self, time_axis_shape, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) @@ -657,7 +647,7 @@ def test_call_sample_axis(self, time_axis_shape, mode, kwargs, cls): ], ) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_sample_range(self, mn, mx, expectation, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) @@ -722,7 +712,7 @@ def test_compute_features_conv_input( order, width, cls, - class_specific_params, + basis_class_specific_params, ): x = np.ones(input_shape) @@ -737,7 +727,9 @@ def test_compute_features_conv_input( ) # figure out which kwargs needs to be removed - kwargs = inspect_utils.trim_kwargs(cls["conv"], kwargs, class_specific_params) + kwargs = inspect_utils.trim_kwargs( + cls["conv"], kwargs, basis_class_specific_params + ) basis_obj = cls["conv"](**kwargs) out = basis_obj.compute_features(x) @@ -913,7 +905,7 @@ def test_convolution_is_performed(self, cls): @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_evaluate_on_grid_basis_size(self, sample_size, mode, kwargs, cls): if "OrthExp" in cls["eval"].__name__: @@ -932,7 +924,7 @@ def test_evaluate_on_grid_basis_size(self, sample_size, mode, kwargs, cls): @pytest.mark.parametrize("n_input", [0, 1, 2]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_evaluate_on_grid_input_number(self, n_input, mode, kwargs, cls): basis_obj = cls[mode]( @@ -957,7 +949,7 @@ def test_evaluate_on_grid_input_number(self, n_input, mode, kwargs, cls): @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_evaluate_on_grid_meshgrid_size(self, sample_size, mode, kwargs, cls): if "OrthExp" in cls["eval"].__name__: @@ -991,7 +983,7 @@ def test_fit_kernel_shape(self, cls): @pytest.mark.parametrize( "mode, ws, expectation", [ - ("conv", 2, does_not_raise()), + ("conv", 5, does_not_raise()), ( "conv", -1, @@ -1036,9 +1028,9 @@ def test_init_window_size(self, mode, ws, expectation, cls): n_basis_funcs=5, window_size=ws, **extra_decay_rates(cls[mode], 5) ) - @pytest.mark.parametrize("samples", [[], [0], [0, 0]]) + @pytest.mark.parametrize("samples", [[], [0] * 10, [0] * 11]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_non_empty_samples(self, samples, mode, kwargs, cls): if "OrthExp" in cls["eval"].__name__: @@ -1083,7 +1075,7 @@ def test_number_of_required_inputs_compute_features( basis_obj.compute_features(*inputs) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_pynapple_support(self, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) @@ -1181,7 +1173,7 @@ def test_set_params( decay_rates, conv_kwargs, cls, - class_specific_params, + basis_class_specific_params, ): """Test the read-only and read/write property of the parameters.""" pars = dict( @@ -1198,7 +1190,7 @@ def test_set_params( pars = { key: value for key, value in pars.items() - if key in class_specific_params[cls[mode].__name__] + if key in basis_class_specific_params[cls[mode].__name__] } keys = list(pars.keys()) @@ -1242,15 +1234,16 @@ def test_set_window_size(self, mode, expectation, cls): def test_transform_fails(self, cls): bas = cls["conv"]( - n_basis_funcs=5, window_size=3, **extra_decay_rates(cls["conv"], 5) + n_basis_funcs=5, window_size=5, **extra_decay_rates(cls["conv"], 5) ) with pytest.raises( - ValueError, match="You must call `_set_kernel` before `_compute_features`" + ValueError, match="You must call `setup_basis` before `_compute_features`" ): bas._compute_features(np.linspace(0, 1, 10)) def test_transformer_get_params(self, cls): bas = cls["eval"](n_basis_funcs=5, **extra_decay_rates(cls["eval"], 5)) + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) bas_transformer = bas.to_transformer() params_transf = bas_transformer.get_params() params_transf.pop("_basis") @@ -1355,7 +1348,7 @@ def test_decay_to_zero_basis_number_match(self, width): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_minimum_number_of_basis_required_is_matched( self, n_basis_funcs, mode, kwargs @@ -1466,7 +1459,7 @@ def test_time_scaling_values(self, time_scaling, expectation, mode, kwargs): ], ) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_width_values(self, width, expectation, mode, kwargs): with expectation: @@ -1478,7 +1471,7 @@ class TestRaisedCosineLinearBasis(BasisFuncsTesting): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_minimum_number_of_basis_required_is_matched( self, n_basis_funcs, mode, kwargs @@ -1553,7 +1546,7 @@ class TestMSplineBasis(BasisFuncsTesting): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("order", [-1, 0, 1, 2, 3, 4, 5]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_minimum_number_of_basis_required_is_matched( self, n_basis_funcs, order, mode, kwargs @@ -1654,6 +1647,50 @@ def test_vmin_vmax_eval_on_grid_scaling_effect_on_eval( class TestOrthExponentialBasis(BasisFuncsTesting): cls = {"eval": basis.OrthExponentialEval, "conv": basis.OrthExponentialConv} + @pytest.mark.parametrize( + "window_size, n_basis, expectation", + [ + ( + 4, + 5, + pytest.raises( + ValueError, + match="OrthExponentialConv basis requires at least a window_size", + ), + ), + (5, 5, does_not_raise()), + ], + ) + def test_window_size_at_init(self, window_size, n_basis, expectation): + decay_rates = np.asarray(np.arange(1, n_basis + 1), dtype=float) + with expectation: + self.cls["conv"](n_basis, decay_rates=decay_rates, window_size=window_size) + + @pytest.mark.parametrize( + "window_size, n_basis, expectation", + [ + ( + 4, + 5, + pytest.raises( + ValueError, + match="OrthExponentialConv basis requires at least a window_size", + ), + ), + (5, 5, does_not_raise()), + ], + ) + def test_window_size_at_init(self, window_size, n_basis, expectation): + decay_rates = np.asarray(np.arange(1, n_basis + 1), dtype=float) + obj = self.cls["conv"]( + n_basis, decay_rates=decay_rates, window_size=n_basis + 1 + ) + with expectation: + obj.window_size = window_size + + with expectation: + obj.set_params(window_size=window_size) + @pytest.mark.parametrize( "decay_rates", [[1, 2, 3], [0.01, 0.02, 0.001], [2, 1, 1, 2.4]] ) @@ -1729,7 +1766,7 @@ class TestBSplineBasis(BasisFuncsTesting): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("order", [1, 2, 3, 4, 5]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_minimum_number_of_basis_required_is_matched( self, n_basis_funcs, order, mode, kwargs @@ -1817,7 +1854,7 @@ class TestCyclicBSplineBasis(BasisFuncsTesting): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("order", [2, 3, 4, 5]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_minimum_number_of_basis_required_is_matched( self, n_basis_funcs, order, mode, kwargs @@ -1918,72 +1955,16 @@ def test_samples_range_matches_compute_features_requirements( basis_obj.compute_features(np.linspace(*sample_range, 100)) -class CombinedBasis(BasisFuncsTesting): - """ - This class is used to run tests on combination operations (e.g., addition, multiplication) among Basis functions. - - Properties: - - cls: Class (default = None) - """ - - cls = None - - @staticmethod - def instantiate_basis( - n_basis, basis_class, class_specific_params, window_size=10, **kwargs - ): - """Instantiate and return two basis of the type specified.""" - - # Set non-optional args - default_kwargs = { - "n_basis_funcs": n_basis, - "window_size": window_size, - "decay_rates": np.arange(1, 1 + n_basis), - } - repeated_keys = set(default_kwargs.keys()).intersection(kwargs.keys()) - if repeated_keys: - raise ValueError( - "Cannot set `n_basis_funcs, window_size, decay_rates` with kwargs" - ) - - # Merge with provided extra kwargs - kwargs = {**default_kwargs, **kwargs} - - if basis_class == AdditiveBasis: - kwargs_mspline = inspect_utils.trim_kwargs( - basis.MSplineEval, kwargs, class_specific_params - ) - kwargs_raised_cosine = inspect_utils.trim_kwargs( - basis.RaisedCosineLinearConv, kwargs, class_specific_params - ) - b1 = basis.MSplineEval(**kwargs_mspline) - b2 = basis.RaisedCosineLinearConv(**kwargs_raised_cosine) - basis_obj = b1 + b2 - elif basis_class == MultiplicativeBasis: - kwargs_mspline = inspect_utils.trim_kwargs( - basis.MSplineEval, kwargs, class_specific_params - ) - kwargs_raised_cosine = inspect_utils.trim_kwargs( - basis.RaisedCosineLinearConv, kwargs, class_specific_params - ) - b1 = basis.MSplineEval(**kwargs_mspline) - b2 = basis.RaisedCosineLinearConv(**kwargs_raised_cosine) - basis_obj = b1 * b2 - else: - basis_obj = basis_class( - **inspect_utils.trim_kwargs(basis_class, kwargs, class_specific_params) - ) - return basis_obj - - class TestAdditiveBasis(CombinedBasis): cls = {"eval": AdditiveBasis, "conv": AdditiveBasis} @pytest.mark.parametrize("samples", [[[0], []], [[], [0]], [[0, 0], [0, 0]]]) @pytest.mark.parametrize("base_cls", [basis.BSplineEval, basis.BSplineConv]) - def test_non_empty_samples(self, base_cls, samples, class_specific_params): + def test_non_empty_samples(self, base_cls, samples, basis_class_specific_params): kwargs = {"window_size": 2, "n_basis_funcs": 5} - kwargs = inspect_utils.trim_kwargs(base_cls, kwargs, class_specific_params) + kwargs = inspect_utils.trim_kwargs( + base_cls, kwargs, basis_class_specific_params + ) basis_obj = base_cls(**kwargs) + base_cls(**kwargs) if any(tuple(len(s) == 0 for s in samples)): with pytest.raises( @@ -2010,6 +1991,68 @@ def test_compute_features_input(self, eval_input): basis_obj = basis.MSplineEval(5) + basis.MSplineEval(5) basis_obj.compute_features(*eval_input) + @pytest.mark.parametrize("n_basis_a", [6]) + @pytest.mark.parametrize("n_basis_b", [5]) + @pytest.mark.parametrize("vmin, vmax", [(-1, 1)]) + @pytest.mark.parametrize("basis_a", list_all_basis_classes()) + @pytest.mark.parametrize("basis_b", list_all_basis_classes()) + @pytest.mark.parametrize("inp_num", [1, 2]) + def test_sklearn_clone( + self, + basis_a, + basis_b, + n_basis_a, + n_basis_b, + vmin, + vmax, + inp_num, + basis_class_specific_params, + ): + """Recursively check cloning.""" + basis_a_obj = self.instantiate_basis( + n_basis_a, basis_a, basis_class_specific_params, window_size=10 + ) + basis_a_obj = basis_a_obj.set_input_shape( + *([inp_num] * basis_a_obj._n_input_dimensionality) + ) + basis_b_obj = self.instantiate_basis( + n_basis_b, basis_b, basis_class_specific_params, window_size=15 + ) + basis_b_obj = basis_b_obj.set_input_shape( + *([inp_num] * basis_b_obj._n_input_dimensionality) + ) + add = basis_a_obj + basis_b_obj + + def filter_attributes(obj, exclude_keys): + return { + key: val for key, val in obj.__dict__.items() if key not in exclude_keys + } + + def compare(b1, b2): + assert id(b1) != id(b2) + assert b1.__class__.__name__ == b2.__class__.__name__ + if hasattr(b1, "basis1"): + compare(b1.basis1, b2.basis1) + compare(b1.basis2, b2.basis2) + # add all params that are not parent or _basis1,_basis2 + d1 = filter_attributes( + b1, exclude_keys=["_basis1", "_basis2", "_parent"] + ) + d2 = filter_attributes( + b2, exclude_keys=["_basis1", "_basis2", "_parent"] + ) + assert d1 == d2 + else: + decay_rates_b1 = b1.__dict__.get("_decay_rates", -1) + decay_rates_b2 = b2.__dict__.get("_decay_rates", -1) + assert np.array_equal(decay_rates_b1, decay_rates_b2) + d1 = filter_attributes(b1, exclude_keys=["_decay_rates", "_parent"]) + d2 = filter_attributes(b2, exclude_keys=["_decay_rates", "_parent"]) + assert d1 == d2 + + add2 = add.__sklearn_clone__() + compare(add, add2) + @pytest.mark.parametrize("n_basis_a", [5, 6]) @pytest.mark.parametrize("n_basis_b", [5, 6]) @pytest.mark.parametrize("sample_size", [10, 1000]) @@ -2024,7 +2067,7 @@ def test_compute_features_returns_expected_number_of_basis( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the evaluation of the `AdditiveBasis` results in a number of basis @@ -2032,10 +2075,10 @@ def test_compute_features_returns_expected_number_of_basis( """ # define the two basis basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj @@ -2063,16 +2106,16 @@ def test_sample_size_of_compute_features_matches_that_of_input( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the output sample size from `AdditiveBasis` compute_features function matches input sample size. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj eval_basis = basis_obj.compute_features( @@ -2100,17 +2143,17 @@ def test_number_of_required_inputs_compute_features( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the number of required inputs for the `compute_features` function matches the sum of the number of input samples from the two bases. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj required_dim = ( @@ -2132,16 +2175,22 @@ def test_number_of_required_inputs_compute_features( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_meshgrid_size( - self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, + sample_size, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + basis_class_specific_params, ): """ Test whether the resulting meshgrid size matches the sample size input. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj + basis_b_obj res = basis_obj.evaluate_on_grid( @@ -2156,16 +2205,22 @@ def test_evaluate_on_grid_meshgrid_size( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_basis_size( - self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, + sample_size, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + basis_class_specific_params, ): """ Test whether the number sample size output by evaluate_on_grid matches the sample size of the input. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj + basis_b_obj eval_basis = basis_obj.evaluate_on_grid( @@ -2179,17 +2234,23 @@ def test_evaluate_on_grid_basis_size( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_input_number( - self, n_input, basis_a, basis_b, n_basis_a, n_basis_b, class_specific_params + self, + n_input, + basis_a, + basis_b, + n_basis_a, + n_basis_b, + basis_class_specific_params, ): """ Test whether the number of inputs provided to `evaluate_on_grid` matches the sum of the number of input samples required from each of the basis objects. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj + basis_b_obj inputs = [20] * n_input @@ -2211,7 +2272,13 @@ def test_evaluate_on_grid_input_number( @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) def test_pynapple_support_compute_features( - self, basis_a, basis_b, n_basis_a, n_basis_b, sample_size, class_specific_params + self, + basis_a, + basis_b, + n_basis_a, + n_basis_b, + sample_size, + basis_class_specific_params, ): iset = nap.IntervalSet(start=[0, 0.5], end=[0.49999, 1]) inp = nap.Tsd( @@ -2220,9 +2287,9 @@ def test_pynapple_support_compute_features( time_support=iset, ) basis_add = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) + self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) # compute_features the basis over pynapple Tsd objects out = basis_add.compute_features(*([inp] * basis_add._n_input_dimensionality)) @@ -2237,7 +2304,7 @@ def test_pynapple_support_compute_features( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) @pytest.mark.parametrize("num_input", [0, 1, 2, 3, 4, 5]) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) def test_call_input_num( self, n_basis_a, @@ -2246,13 +2313,13 @@ def test_call_input_num( basis_b, num_input, window_size, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj if num_input == basis_obj._n_input_dimensionality: @@ -2271,7 +2338,7 @@ def test_call_input_num( (np.linspace(0, 1, 10)[:, None], pytest.raises(ValueError)), ], ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2285,20 +2352,20 @@ def test_call_input_shape( inp, window_size, expectation, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj with expectation: basis_obj._evaluate(*([inp] * basis_obj._n_input_dimensionality)) @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2311,25 +2378,31 @@ def test_call_sample_axis( basis_b, time_axis_shape, window_size, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj inp = [np.linspace(0, 1, time_axis_shape)] * basis_obj._n_input_dimensionality assert basis_obj._evaluate(*inp).shape[0] == time_axis_shape - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_nan( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): if ( basis_a == basis.OrthExponentialBasis @@ -2337,10 +2410,10 @@ def test_call_nan( ): return basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj inp = [np.linspace(0, 1, 10)] * basis_obj._n_input_dimensionality @@ -2353,40 +2426,46 @@ def test_call_nan( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_equivalent_in_conv( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=3 + n_basis_a, basis_a, basis_class_specific_params, window_size=9 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=3 + n_basis_b, basis_b, basis_class_specific_params, window_size=9 ) bas_eva = basis_a_obj + basis_b_obj basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=8 + n_basis_a, basis_a, basis_class_specific_params, window_size=8 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=8 + n_basis_b, basis_b, basis_class_specific_params, window_size=8 ) bas_con = basis_a_obj + basis_b_obj x = [np.linspace(0, 1, 10)] * bas_con._n_input_dimensionality assert np.all(bas_con._evaluate(*x) == bas_eva._evaluate(*x)) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_pynapple_support( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj x = np.linspace(0, 1, 10) @@ -2398,19 +2477,25 @@ def test_pynapple_support( assert np.all(y == y_nap.d) assert np.all(y_nap.t == x_nap[0].t) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [6, 7]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_basis_number( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality @@ -2419,19 +2504,25 @@ def test_call_basis_number( == basis_a_obj.n_basis_funcs + basis_b_obj.n_basis_funcs ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_non_empty( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj with pytest.raises(ValueError, match="All sample provided must"): @@ -2445,7 +2536,7 @@ def test_call_non_empty( (0.1, 2, does_not_raise()), ], ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2460,7 +2551,7 @@ def test_call_sample_range( mx, expectation, window_size, - class_specific_params, + basis_class_specific_params, ): if expectation == "check": if ( @@ -2473,10 +2564,10 @@ def test_call_sample_range( else: expectation = does_not_raise() basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj with expectation: @@ -2487,16 +2578,16 @@ def test_call_sample_range( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_fit_kernel( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) bas = basis_a_obj + basis_b_obj - bas.set_kernel() + bas.setup_basis(*([np.ones(10)] * bas._n_input_dimensionality)) def check_kernel(basis_obj): has_kern = [] @@ -2516,13 +2607,13 @@ def check_kernel(basis_obj): @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_transform_fails( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) bas = basis_a_obj + basis_b_obj if "Eval" in basis_a.__name__ and "Eval" in basis_b.__name__: @@ -2530,7 +2621,7 @@ def test_transform_fails( else: context = pytest.raises( ValueError, - match="You must call `_set_kernel` before `_compute_features`", + match="You must call `setup_basis` before `_compute_features`", ) with context: x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality @@ -2554,11 +2645,11 @@ def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): bas1 = basis.RaisedCosineLinearConv(10, window_size=10) bas2 = basis.BSplineConv(10, window_size=10) bas_add = bas1 + bas2 - assert bas_add.n_basis_input is None + assert bas_add.n_basis_input_ is None bas_add.compute_features( np.ones((20, n_basis_input1)), np.ones((20, n_basis_input2)) ) - assert bas_add.n_basis_input == (n_basis_input1, n_basis_input2) + assert bas_add.n_basis_input_ == (n_basis_input1, n_basis_input2) @pytest.mark.parametrize( "n_input, expectation", @@ -2594,16 +2685,16 @@ def test_set_input_shape_type_1d_arrays( basis_b, shape_a, shape_b, - class_specific_params, + basis_class_specific_params, add_shape_a, add_shape_b, ): x = (np.ones((10, *add_shape_a)), np.ones((10, *add_shape_b))) basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) add = basis_a + basis_b @@ -2633,16 +2724,16 @@ def test_set_input_shape_type_2d_arrays( basis_b, shape_a, shape_b, - class_specific_params, + basis_class_specific_params, add_shape_a, add_shape_b, ): x = (np.ones((10, 2, *add_shape_a)), np.ones((10, 3, *add_shape_b))) basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) add = basis_a + basis_b @@ -2672,16 +2763,16 @@ def test_set_input_shape_type_nd_arrays( basis_b, shape_a, shape_b, - class_specific_params, + basis_class_specific_params, add_shape_a, add_shape_b, ): x = (np.ones((10, 2, 2, *add_shape_a)), np.ones((10, 3, 1, *add_shape_b))) basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) add = basis_a + basis_b @@ -2720,18 +2811,103 @@ def test_set_input_shape_type_nd_arrays( "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") ) def test_set_input_value_types( - self, inp_shape, expectation, basis_a, basis_b, class_specific_params + self, inp_shape, expectation, basis_a, basis_b, basis_class_specific_params ): basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) add = basis_a + basis_b with expectation: add.set_input_shape(*inp_shape) + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + def test_deep_copy_basis(self, basis_a, basis_b, basis_class_specific_params): + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + add = basis_a + basis_b + # test pointing to different objects + assert id(add.basis1) != id(basis_a) + assert id(add.basis1) != id(basis_b) + assert id(add.basis2) != id(basis_a) + assert id(add.basis2) != id(basis_b) + + # test attributes are not related + basis_a.n_basis_funcs = 10 + basis_b.n_basis_funcs = 10 + assert add.basis1.n_basis_funcs == 5 + assert add.basis2.n_basis_funcs == 5 + + add.basis1.n_basis_funcs = 6 + add.basis2.n_basis_funcs = 6 + assert basis_a.n_basis_funcs == 10 + assert basis_b.n_basis_funcs == 10 + + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + def test_compute_n_basis_runtime( + self, basis_a, basis_b, basis_class_specific_params + ): + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + add = basis_a + basis_b + add.basis1.n_basis_funcs = 10 + assert add.n_basis_funcs == 15 + add.basis2.n_basis_funcs = 10 + assert add.n_basis_funcs == 20 + + @pytest.mark.parametrize("basis_a", list_all_basis_classes()) + @pytest.mark.parametrize("basis_b", list_all_basis_classes()) + def test_runtime_n_basis_out_compute( + self, basis_a, basis_b, basis_class_specific_params + ): + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_a.set_input_shape( + *([1] * basis_a._n_input_dimensionality) + ).to_transformer() + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + basis_b.set_input_shape( + *([1] * basis_b._n_input_dimensionality) + ).to_transformer() + add = basis_a + basis_b + inps_a = [2] * basis_a._n_input_dimensionality + add.basis1.set_input_shape(*inps_a) + if isinstance(basis_a, MultiplicativeBasis): + new_out_num = np.prod(inps_a) * add.basis1.n_basis_funcs + else: + new_out_num = inps_a[0] * add.basis1.n_basis_funcs + assert add.n_output_features == new_out_num + add.basis2.n_basis_funcs + inps_b = [3] * basis_b._n_input_dimensionality + if isinstance(basis_b, MultiplicativeBasis): + new_out_num_b = np.prod(inps_b) * add.basis2.n_basis_funcs + else: + new_out_num_b = inps_b[0] * add.basis2.n_basis_funcs + add.basis2.set_input_shape(*inps_b) + assert add.n_output_features == new_out_num + new_out_num_b + class TestMultiplicativeBasis(CombinedBasis): cls = {"eval": MultiplicativeBasis, "conv": MultiplicativeBasis} @@ -2781,7 +2957,7 @@ def test_compute_features_returns_expected_number_of_basis( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the evaluation of the `MultiplicativeBasis` results in a number of basis @@ -2789,10 +2965,10 @@ def test_compute_features_returns_expected_number_of_basis( """ # define the two basis basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj @@ -2821,17 +2997,17 @@ def test_sample_size_of_compute_features_matches_that_of_input( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the output sample size from the `MultiplicativeBasis` fit_transform function matches the input sample size. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj eval_basis = basis_obj.compute_features( @@ -2858,17 +3034,17 @@ def test_number_of_required_inputs_compute_features( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the number of required inputs for the `compute_features` function matches the sum of the number of input samples from the two bases. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj required_dim = ( @@ -2890,16 +3066,22 @@ def test_number_of_required_inputs_compute_features( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_meshgrid_size( - self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, + sample_size, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + basis_class_specific_params, ): """ Test whether the resulting meshgrid size matches the sample size input. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj * basis_b_obj res = basis_obj.evaluate_on_grid( @@ -2914,16 +3096,22 @@ def test_evaluate_on_grid_meshgrid_size( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_basis_size( - self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, + sample_size, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + basis_class_specific_params, ): """ Test whether the number sample size output by evaluate_on_grid matches the sample size of the input. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj * basis_b_obj eval_basis = basis_obj.evaluate_on_grid( @@ -2937,17 +3125,23 @@ def test_evaluate_on_grid_basis_size( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_input_number( - self, n_input, basis_a, basis_b, n_basis_a, n_basis_b, class_specific_params + self, + n_input, + basis_a, + basis_b, + n_basis_a, + n_basis_b, + basis_class_specific_params, ): """ Test whether the number of inputs provided to `evaluate_on_grid` matches the sum of the number of input samples required from each of the basis objects. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj * basis_b_obj inputs = [20] * n_input @@ -2977,15 +3171,15 @@ def test_inconsistent_sample_sizes( n_basis_b, sample_size_a, sample_size_b, - class_specific_params, + basis_class_specific_params, ): """Test that the inputs of inconsistent sample sizes result in an exception when compute_features is called""" raise_exception = sample_size_a != sample_size_b basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) input_a = [ np.linspace(0, 1, sample_size_a) @@ -3009,7 +3203,13 @@ def test_inconsistent_sample_sizes( @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) def test_pynapple_support_compute_features( - self, basis_a, basis_b, n_basis_a, n_basis_b, sample_size, class_specific_params + self, + basis_a, + basis_b, + n_basis_a, + n_basis_b, + sample_size, + basis_class_specific_params, ): iset = nap.IntervalSet(start=[0, 0.5], end=[0.49999, 1]) inp = nap.Tsd( @@ -3018,9 +3218,9 @@ def test_pynapple_support_compute_features( time_support=iset, ) basis_prod = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) * self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) out = basis_prod.compute_features(*([inp] * basis_prod._n_input_dimensionality)) assert isinstance(out, nap.TsdFrame) @@ -3032,7 +3232,7 @@ def test_pynapple_support_compute_features( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) @pytest.mark.parametrize("num_input", [0, 1, 2, 3, 4, 5]) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) def test_call_input_num( self, n_basis_a, @@ -3041,13 +3241,13 @@ def test_call_input_num( basis_b, num_input, window_size, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj if num_input == basis_obj._n_input_dimensionality: @@ -3066,7 +3266,7 @@ def test_call_input_num( (np.linspace(0, 1, 10)[:, None], pytest.raises(ValueError)), ], ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -3080,20 +3280,20 @@ def test_call_input_shape( inp, window_size, expectation, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj with expectation: basis_obj._evaluate(*([inp] * basis_obj._n_input_dimensionality)) @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -3106,25 +3306,31 @@ def test_call_sample_axis( basis_b, time_axis_shape, window_size, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj inp = [np.linspace(0, 1, time_axis_shape)] * basis_obj._n_input_dimensionality assert basis_obj._evaluate(*inp).shape[0] == time_axis_shape - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_nan( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): if ( basis_a == basis.OrthExponentialBasis @@ -3132,10 +3338,10 @@ def test_call_nan( ): return basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj inp = [np.linspace(0, 1, 10)] * basis_obj._n_input_dimensionality @@ -3148,40 +3354,46 @@ def test_call_nan( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_equivalent_in_conv( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) bas_eva = basis_a_obj * basis_b_obj basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=8 + n_basis_a, basis_a, basis_class_specific_params, window_size=8 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=8 + n_basis_b, basis_b, basis_class_specific_params, window_size=8 ) bas_con = basis_a_obj * basis_b_obj x = [np.linspace(0, 1, 10)] * bas_con._n_input_dimensionality assert np.all(bas_con._evaluate(*x) == bas_eva._evaluate(*x)) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_pynapple_support( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj x = np.linspace(0, 1, 10) @@ -3193,19 +3405,25 @@ def test_pynapple_support( assert np.all(y == y_nap.d) assert np.all(y_nap.t == x_nap[0].t) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [6, 7]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_basis_number( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality @@ -3214,19 +3432,25 @@ def test_call_basis_number( == basis_a_obj.n_basis_funcs * basis_b_obj.n_basis_funcs ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_non_empty( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj with pytest.raises(ValueError, match="All sample provided must"): @@ -3240,7 +3464,7 @@ def test_call_non_empty( (0.1, 2, does_not_raise()), ], ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -3255,7 +3479,7 @@ def test_call_sample_range( mx, expectation, window_size, - class_specific_params, + basis_class_specific_params, ): if expectation == "check": if ( @@ -3268,10 +3492,10 @@ def test_call_sample_range( else: expectation = does_not_raise() basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj with expectation: @@ -3282,16 +3506,16 @@ def test_call_sample_range( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_fit_kernel( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) bas = basis_a_obj * basis_b_obj - bas.set_kernel() + bas._set_input_independent_states() def check_kernel(basis_obj): has_kern = [] @@ -3311,13 +3535,13 @@ def check_kernel(basis_obj): @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_transform_fails( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) bas = basis_a_obj * basis_b_obj if "Eval" in basis_a.__name__ and "Eval" in basis_b.__name__: @@ -3325,7 +3549,7 @@ def test_transform_fails( else: context = pytest.raises( ValueError, - match="You must call `_set_kernel` before `_compute_features`", + match="You must call `setup_basis` before `_compute_features`", ) with context: x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality @@ -3349,11 +3573,11 @@ def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): bas1 = basis.RaisedCosineLinearConv(10, window_size=10) bas2 = basis.BSplineConv(10, window_size=10) bas_add = bas1 * bas2 - assert bas_add.n_basis_input is None + assert bas_add.n_basis_input_ is None bas_add.compute_features( np.ones((20, n_basis_input1)), np.ones((20, n_basis_input2)) ) - assert bas_add.n_basis_input == (n_basis_input1, n_basis_input2) + assert bas_add.n_basis_input_ == (n_basis_input1, n_basis_input2) @pytest.mark.parametrize( "n_input, expectation", @@ -3375,14 +3599,14 @@ def test_expected_input_number(self, n_input, expectation): @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) - def test_n_basis_input(self, n_basis_input1, n_basis_input2): + def test_n_basis_input_(self, n_basis_input1, n_basis_input2): bas1 = basis.RaisedCosineLinearConv(10, window_size=10) bas2 = basis.BSplineConv(10, window_size=10) bas_prod = bas1 * bas2 bas_prod.compute_features( np.ones((20, n_basis_input1)), np.ones((20, n_basis_input2)) ) - assert bas_prod.n_basis_input == (n_basis_input1, n_basis_input2) + assert bas_prod.n_basis_input_ == (n_basis_input1, n_basis_input2) @pytest.mark.parametrize( "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") @@ -3400,16 +3624,16 @@ def test_set_input_shape_type_1d_arrays( basis_b, shape_a, shape_b, - class_specific_params, + basis_class_specific_params, add_shape_a, add_shape_b, ): x = (np.ones((10, *add_shape_a)), np.ones((10, *add_shape_b))) basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) mul = basis_a * basis_b @@ -3439,16 +3663,16 @@ def test_set_input_shape_type_2d_arrays( basis_b, shape_a, shape_b, - class_specific_params, + basis_class_specific_params, add_shape_a, add_shape_b, ): x = (np.ones((10, 2, *add_shape_a)), np.ones((10, 3, *add_shape_b))) basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) mul = basis_a * basis_b @@ -3478,16 +3702,16 @@ def test_set_input_shape_type_nd_arrays( basis_b, shape_a, shape_b, - class_specific_params, + basis_class_specific_params, add_shape_a, add_shape_b, ): x = (np.ones((10, 2, 2, *add_shape_a)), np.ones((10, 3, 1, *add_shape_b))) basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) mul = basis_a * basis_b @@ -3526,24 +3750,109 @@ def test_set_input_shape_type_nd_arrays( "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") ) def test_set_input_value_types( - self, inp_shape, expectation, basis_a, basis_b, class_specific_params + self, inp_shape, expectation, basis_a, basis_b, basis_class_specific_params ): basis_a = self.instantiate_basis( - 5, basis_a, class_specific_params, window_size=10 + 5, basis_a, basis_class_specific_params, window_size=10 ) basis_b = self.instantiate_basis( - 5, basis_b, class_specific_params, window_size=10 + 5, basis_b, basis_class_specific_params, window_size=10 ) mul = basis_a * basis_b with expectation: mul.set_input_shape(*inp_shape) + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + def test_deep_copy_basis(self, basis_a, basis_b, basis_class_specific_params): + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + mul = basis_a * basis_b + # test pointing to different objects + assert id(mul.basis1) != id(basis_a) + assert id(mul.basis1) != id(basis_b) + assert id(mul.basis2) != id(basis_a) + assert id(mul.basis2) != id(basis_b) + + # test attributes are not related + basis_a.n_basis_funcs = 10 + basis_b.n_basis_funcs = 10 + assert mul.basis1.n_basis_funcs == 5 + assert mul.basis2.n_basis_funcs == 5 + + mul.basis1.n_basis_funcs = 6 + mul.basis2.n_basis_funcs = 6 + assert basis_a.n_basis_funcs == 10 + assert basis_b.n_basis_funcs == 10 + + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + def test_compute_n_basis_runtime( + self, basis_a, basis_b, basis_class_specific_params + ): + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + mul = basis_a * basis_b + mul.basis1.n_basis_funcs = 10 + assert mul.n_basis_funcs == 50 + mul.basis2.n_basis_funcs = 10 + assert mul.n_basis_funcs == 100 + + @pytest.mark.parametrize("basis_a", list_all_basis_classes()) + @pytest.mark.parametrize("basis_b", list_all_basis_classes()) + def test_runtime_n_basis_out_compute( + self, basis_a, basis_b, basis_class_specific_params + ): + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_a.set_input_shape( + *([1] * basis_a._n_input_dimensionality) + ).to_transformer() + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + basis_b.set_input_shape( + *([1] * basis_b._n_input_dimensionality) + ).to_transformer() + mul = basis_a * basis_b + inps_a = [2] * basis_a._n_input_dimensionality + mul.basis1.set_input_shape(*inps_a) + if isinstance(basis_a, MultiplicativeBasis): + new_out_num = np.prod(inps_a) * mul.basis1.n_basis_funcs + else: + new_out_num = inps_a[0] * mul.basis1.n_basis_funcs + assert mul.n_output_features == new_out_num * mul.basis2.n_basis_funcs + inps_b = [3] * basis_b._n_input_dimensionality + if isinstance(basis_b, MultiplicativeBasis): + new_out_num_b = np.prod(inps_b) * mul.basis2.n_basis_funcs + else: + new_out_num_b = inps_b[0] * mul.basis2.n_basis_funcs + mul.basis2.set_input_shape(*inps_b) + assert mul.n_output_features == new_out_num * new_out_num_b + @pytest.mark.parametrize( "exponent", [-1, 0, 0.5, basis.RaisedCosineLogEval(4), 1, 2, 3] ) @pytest.mark.parametrize("basis_class", list_all_basis_classes()) -def test_power_of_basis(exponent, basis_class, class_specific_params): +def test_power_of_basis(exponent, basis_class, basis_class_specific_params): """Test if the power behaves as expected.""" raise_exception_type = not type(exponent) is int @@ -3553,7 +3862,7 @@ def test_power_of_basis(exponent, basis_class, class_specific_params): raise_exception_value = False basis_obj = CombinedBasis.instantiate_basis( - 5, basis_class, class_specific_params, window_size=10 + 5, basis_class, basis_class_specific_params, window_size=10 ) if raise_exception_type: @@ -3589,13 +3898,14 @@ def test_power_of_basis(exponent, basis_class, class_specific_params): "basis_cls", list_all_basis_classes(), ) -def test_basis_to_transformer(basis_cls, class_specific_params): +def test_basis_to_transformer(basis_cls, basis_class_specific_params): n_basis_funcs = 5 bas = CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=10 + n_basis_funcs, basis_cls, basis_class_specific_params, window_size=10 ) - - trans_bas = bas.to_transformer() + trans_bas = bas.set_input_shape( + *([1] * bas._n_input_dimensionality) + ).to_transformer() assert isinstance(trans_bas, basis.TransformerBasis) @@ -3607,386 +3917,6 @@ def test_basis_to_transformer(basis_cls, class_specific_params): assert np.all(getattr(bas, k) == getattr(trans_bas, k)) -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_transformer_has_the_same_public_attributes_as_basis( - basis_cls, class_specific_params -): - n_basis_funcs = 5 - bas = CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=10 - ) - - public_attrs_basis = {attr for attr in dir(bas) if not attr.startswith("_")} - public_attrs_transformerbasis = { - attr for attr in dir(bas.to_transformer()) if not attr.startswith("_") - } - - assert public_attrs_transformerbasis - public_attrs_basis == { - "fit", - "fit_transform", - "transform", - } - - assert public_attrs_basis - public_attrs_transformerbasis == set() - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), -) -def test_to_transformer_and_constructor_are_equivalent( - basis_cls, class_specific_params -): - n_basis_funcs = 5 - bas = CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=10 - ) - - trans_bas_a = bas.to_transformer() - trans_bas_b = basis.TransformerBasis(bas) - - # they both just have a _basis - assert ( - list(trans_bas_a.__dict__.keys()) - == list(trans_bas_b.__dict__.keys()) - == ["_basis"] - ) - # and those bases are the same - assert np.all( - trans_bas_a._basis.__dict__.pop("_decay_rates", 1) - == trans_bas_b._basis.__dict__.pop("_decay_rates", 1) - ) - assert trans_bas_a._basis.__dict__ == trans_bas_b._basis.__dict__ - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_basis_to_transformer_makes_a_copy(basis_cls, class_specific_params): - bas_a = CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 - ) - trans_bas_a = bas_a.to_transformer() - - # changing an attribute in bas should not change trans_bas - if basis_cls in [AdditiveBasis, MultiplicativeBasis]: - bas_a._basis1.n_basis_funcs = 10 - assert trans_bas_a._basis._basis1.n_basis_funcs == 5 - - # changing an attribute in the transformer basis should not change the original - bas_b = CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 - ) - trans_bas_b = bas_b.to_transformer() - trans_bas_b._basis._basis1.n_basis_funcs = 100 - assert bas_b._basis1.n_basis_funcs == 5 - else: - bas_a.n_basis_funcs = 10 - assert trans_bas_a.n_basis_funcs == 5 - - # changing an attribute in the transformer basis should not change the original - bas_b = CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 - ) - trans_bas_b = bas_b.to_transformer() - trans_bas_b.n_basis_funcs = 100 - assert bas_b.n_basis_funcs == 5 - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -@pytest.mark.parametrize("n_basis_funcs", [5, 10, 20]) -def test_transformerbasis_getattr(basis_cls, n_basis_funcs, class_specific_params): - trans_basis = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=10 - ) - ) - if basis_cls in [AdditiveBasis, MultiplicativeBasis]: - for bas in [ - getattr(trans_basis._basis, attr) for attr in ("_basis1", "_basis2") - ]: - assert bas.n_basis_funcs == n_basis_funcs - else: - assert trans_basis.n_basis_funcs == n_basis_funcs - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), -) -@pytest.mark.parametrize("n_basis_funcs_init", [5]) -@pytest.mark.parametrize("n_basis_funcs_new", [6, 10, 20]) -def test_transformerbasis_set_params( - basis_cls, n_basis_funcs_init, n_basis_funcs_new, class_specific_params -): - trans_basis = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - n_basis_funcs_init, basis_cls, class_specific_params, window_size=10 - ) - ) - trans_basis.set_params(n_basis_funcs=n_basis_funcs_new) - - assert trans_basis.n_basis_funcs == n_basis_funcs_new - assert trans_basis._basis.n_basis_funcs == n_basis_funcs_new - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), -) -def test_transformerbasis_setattr_basis(basis_cls, class_specific_params): - # setting the _basis attribute should change it - trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=10 - ) - ) - trans_bas._basis = CombinedBasis().instantiate_basis( - 20, basis_cls, class_specific_params, window_size=10 - ) - - assert trans_bas.n_basis_funcs == 20 - assert trans_bas._basis.n_basis_funcs == 20 - assert isinstance(trans_bas._basis, basis_cls) - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), -) -def test_transformerbasis_setattr_basis_attribute(basis_cls, class_specific_params): - # setting an attribute that is an attribute of the underlying _basis - # should propagate setting it on _basis itself - trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=10 - ) - ) - trans_bas.n_basis_funcs = 20 - - assert trans_bas.n_basis_funcs == 20 - assert trans_bas._basis.n_basis_funcs == 20 - assert isinstance(trans_bas._basis, basis_cls) - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), -) -def test_transformerbasis_copy_basis_on_contsruct(basis_cls, class_specific_params): - # modifying the transformerbasis's attributes shouldn't - # touch the original basis that was used to create it - orig_bas = CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=10 - ) - trans_bas = basis.TransformerBasis(orig_bas) - trans_bas.n_basis_funcs = 20 - - assert orig_bas.n_basis_funcs == 10 - assert trans_bas._basis.n_basis_funcs == 20 - assert trans_bas._basis.n_basis_funcs == 20 - assert isinstance(trans_bas._basis, basis_cls) - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_transformerbasis_setattr_illegal_attribute(basis_cls, class_specific_params): - # changing an attribute that is not _basis or an attribute of _basis - # is not allowed - trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=10 - ) - ) - - with pytest.raises( - ValueError, - match="Only setting _basis or existing attributes of _basis is allowed.", - ): - trans_bas.random_attr = "random value" - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_transformerbasis_addition(basis_cls, class_specific_params): - n_basis_funcs_a = 5 - n_basis_funcs_b = n_basis_funcs_a * 2 - bas_a = CombinedBasis().instantiate_basis( - n_basis_funcs_a, basis_cls, class_specific_params, window_size=10 - ) - bas_b = CombinedBasis().instantiate_basis( - n_basis_funcs_b, basis_cls, class_specific_params, window_size=10 - ) - trans_bas_a = basis.TransformerBasis(bas_a) - trans_bas_b = basis.TransformerBasis(bas_b) - trans_bas_sum = trans_bas_a + trans_bas_b - assert isinstance(trans_bas_sum, basis.TransformerBasis) - assert isinstance(trans_bas_sum._basis, AdditiveBasis) - assert ( - trans_bas_sum.n_basis_funcs - == trans_bas_a.n_basis_funcs + trans_bas_b.n_basis_funcs - ) - assert ( - trans_bas_sum._n_input_dimensionality - == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality - ) - if basis_cls not in [AdditiveBasis, MultiplicativeBasis]: - assert trans_bas_sum._basis1.n_basis_funcs == n_basis_funcs_a - assert trans_bas_sum._basis2.n_basis_funcs == n_basis_funcs_b - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_transformerbasis_multiplication(basis_cls, class_specific_params): - n_basis_funcs_a = 5 - n_basis_funcs_b = n_basis_funcs_a * 2 - trans_bas_a = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - n_basis_funcs_a, basis_cls, class_specific_params, window_size=10 - ) - ) - trans_bas_b = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - n_basis_funcs_b, basis_cls, class_specific_params, window_size=10 - ) - ) - trans_bas_prod = trans_bas_a * trans_bas_b - assert isinstance(trans_bas_prod, basis.TransformerBasis) - assert isinstance(trans_bas_prod._basis, MultiplicativeBasis) - assert ( - trans_bas_prod.n_basis_funcs - == trans_bas_a.n_basis_funcs * trans_bas_b.n_basis_funcs - ) - assert ( - trans_bas_prod._n_input_dimensionality - == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality - ) - if basis_cls not in [AdditiveBasis, MultiplicativeBasis]: - assert trans_bas_prod._basis1.n_basis_funcs == n_basis_funcs_a - assert trans_bas_prod._basis2.n_basis_funcs == n_basis_funcs_b - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -@pytest.mark.parametrize( - "exponent, error_type, error_message", - [ - (2, does_not_raise, None), - (5, does_not_raise, None), - (0.5, TypeError, "Exponent should be an integer"), - (-1, ValueError, "Exponent should be a non-negative integer"), - ], -) -def test_transformerbasis_exponentiation( - basis_cls, exponent: int, error_type, error_message, class_specific_params -): - trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 - ) - ) - - if not isinstance(exponent, int): - with pytest.raises(error_type, match=error_message): - trans_bas_exp = trans_bas**exponent - assert isinstance(trans_bas_exp, basis.TransformerBasis) - assert isinstance(trans_bas_exp._basis, MultiplicativeBasis) - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_transformerbasis_dir(basis_cls, class_specific_params): - trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 - ) - ) - for attr_name in ( - "fit", - "transform", - "fit_transform", - "n_basis_funcs", - "mode", - "window_size", - ): - if ( - attr_name == "window_size" - and "Conv" not in trans_bas._basis.__class__.__name__ - ): - continue - assert attr_name in dir(trans_bas) - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv"), -) -def test_transformerbasis_sk_clone_kernel_noned(basis_cls, class_specific_params): - orig_bas = CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=20 - ) - trans_bas = basis.TransformerBasis(orig_bas) - - # kernel should be saved in the object after fit - trans_bas.fit(np.random.randn(100, 20)) - assert isinstance(trans_bas.kernel_, np.ndarray) - - # cloning should set kernel_ to None - trans_bas_clone = sk_clone(trans_bas) - - # the original object should still have kernel_ - assert isinstance(trans_bas.kernel_, np.ndarray) - # but the clone should not have one - assert trans_bas_clone.kernel_ is None - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -@pytest.mark.parametrize("n_basis_funcs", [5]) -def test_transformerbasis_pickle( - tmpdir, basis_cls, n_basis_funcs, class_specific_params -): - # the test that tries cross-validation with n_jobs = 2 already should test this - trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=10 - ) - ) - filepath = tmpdir / "transformerbasis.pickle" - with open(filepath, "wb") as f: - pickle.dump(trans_bas, f) - with open(filepath, "rb") as f: - trans_bas2 = pickle.load(f) - - assert isinstance(trans_bas2, basis.TransformerBasis) - if basis_cls in [AdditiveBasis, MultiplicativeBasis]: - for bas in [ - getattr(trans_bas2._basis, attr) for attr in ("_basis1", "_basis2") - ]: - assert bas.n_basis_funcs == n_basis_funcs - else: - assert trans_bas2.n_basis_funcs == n_basis_funcs - - @pytest.mark.parametrize( "tsd", [ @@ -4025,7 +3955,7 @@ def test_multi_epoch_pynapple_basis( shift, predictor_causality, nan_index, - class_specific_params, + basis_class_specific_params, ): """Test nan location in multi-epoch pynapple tsd.""" kwargs = dict( @@ -4039,7 +3969,11 @@ def test_multi_epoch_pynapple_basis( else: nbasis = 5 bas = CombinedBasis().instantiate_basis( - nbasis, basis_cls, class_specific_params, window_size=window_size, **kwargs + nbasis, + basis_cls, + basis_class_specific_params, + window_size=window_size, + **kwargs, ) n_input = bas._n_input_dimensionality @@ -4092,7 +4026,7 @@ def test_multi_epoch_pynapple_basis_transformer( shift, predictor_causality, nan_index, - class_specific_params, + basis_class_specific_params, ): """Test nan location in multi-epoch pynapple tsd.""" kwargs = dict( @@ -4106,18 +4040,22 @@ def test_multi_epoch_pynapple_basis_transformer( nbasis = 5 bas = CombinedBasis().instantiate_basis( - nbasis, basis_cls, class_specific_params, window_size=window_size, **kwargs + nbasis, + basis_cls, + basis_class_specific_params, + window_size=window_size, + **kwargs, ) n_input = bas._n_input_dimensionality - # pass through transformer - bas = basis.TransformerBasis(bas) - # concat input X = pynapple_concatenate_numpy([tsd[:, None]] * n_input, axis=1) # run convolutions + # pass through transformer + bas.set_input_shape(X) + bas = basis.TransformerBasis(bas) res = bas.fit_transform(X) # check nans @@ -4140,18 +4078,18 @@ def test_multi_epoch_pynapple_basis_transformer( "__add__", "__add__", lambda bas1, bas2, bas3: { - "1": slice(0, bas1._n_basis_input[0] * bas1.n_basis_funcs), + "1": slice(0, bas1._n_basis_input_[0] * bas1.n_basis_funcs), "2": slice( - bas1._n_basis_input[0] * bas1.n_basis_funcs, - bas1._n_basis_input[0] * bas1.n_basis_funcs - + bas2._n_basis_input[0] * bas2.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs + + bas2._n_basis_input_[0] * bas2.n_basis_funcs, ), "3": slice( - bas1._n_basis_input[0] * bas1.n_basis_funcs - + bas2._n_basis_input[0] * bas2.n_basis_funcs, - bas1._n_basis_input[0] * bas1.n_basis_funcs - + bas2._n_basis_input[0] * bas2.n_basis_funcs - + bas3._n_basis_input[0] * bas3.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs + + bas2._n_basis_input_[0] * bas2.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs + + bas2._n_basis_input_[0] * bas2.n_basis_funcs + + bas3._n_basis_input_[0] * bas3.n_basis_funcs, ), }, ), @@ -4159,13 +4097,13 @@ def test_multi_epoch_pynapple_basis_transformer( "__add__", "__mul__", lambda bas1, bas2, bas3: { - "1": slice(0, bas1._n_basis_input[0] * bas1.n_basis_funcs), + "1": slice(0, bas1._n_basis_input_[0] * bas1.n_basis_funcs), "(2 * 3)": slice( - bas1._n_basis_input[0] * bas1.n_basis_funcs, - bas1._n_basis_input[0] * bas1.n_basis_funcs - + bas2._n_basis_input[0] + bas1._n_basis_input_[0] * bas1.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs + + bas2._n_basis_input_[0] * bas2.n_basis_funcs - * bas3._n_basis_input[0] + * bas3._n_basis_input_[0] * bas3.n_basis_funcs, ), }, @@ -4177,11 +4115,11 @@ def test_multi_epoch_pynapple_basis_transformer( # note that it doesn't respect algebra order but execute right to left (first add then multiplies) "(1 * (2 + 3))": slice( 0, - bas1._n_basis_input[0] + bas1._n_basis_input_[0] * bas1.n_basis_funcs * ( - bas2._n_basis_input[0] * bas2.n_basis_funcs - + bas3._n_basis_input[0] * bas3.n_basis_funcs + bas2._n_basis_input_[0] * bas2.n_basis_funcs + + bas3._n_basis_input_[0] * bas3.n_basis_funcs ), ), }, @@ -4192,11 +4130,11 @@ def test_multi_epoch_pynapple_basis_transformer( lambda bas1, bas2, bas3: { "(1 * (2 * 3))": slice( 0, - bas1._n_basis_input[0] + bas1._n_basis_input_[0] * bas1.n_basis_funcs - * bas2._n_basis_input[0] + * bas2._n_basis_input_[0] * bas2.n_basis_funcs - * bas3._n_basis_input[0] + * bas3._n_basis_input_[0] * bas3.n_basis_funcs, ), }, @@ -4204,7 +4142,7 @@ def test_multi_epoch_pynapple_basis_transformer( ], ) def test__get_splitter( - bas1, bas2, bas3, operator1, operator2, compute_slice, class_specific_params + bas1, bas2, bas3, operator1, operator2, compute_slice, basis_class_specific_params ): # skip nested if any( @@ -4218,13 +4156,22 @@ def test__get_splitter( combine_basis = CombinedBasis() bas1_instance = combine_basis.instantiate_basis( - n_basis[0], bas1, class_specific_params, window_size=10, label="1" + n_basis[0], bas1, basis_class_specific_params, window_size=10, label="1" + ) + bas1_instance.set_input_shape( + *([n_input_basis[0]] * bas1_instance._n_input_dimensionality) ) bas2_instance = combine_basis.instantiate_basis( - n_basis[1], bas2, class_specific_params, window_size=10, label="2" + n_basis[1], bas2, basis_class_specific_params, window_size=10, label="2" + ) + bas2_instance.set_input_shape( + *([n_input_basis[1]] * bas2_instance._n_input_dimensionality) ) bas3_instance = combine_basis.instantiate_basis( - n_basis[2], bas3, class_specific_params, window_size=10, label="3" + n_basis[2], bas3, basis_class_specific_params, window_size=10, label="3" + ) + bas3_instance.set_input_shape( + *([n_input_basis[2]] * bas3_instance._n_input_dimensionality) ) func1 = getattr(bas1_instance, operator1) @@ -4250,11 +4197,11 @@ def test__get_splitter( 1, 1, lambda bas1, bas2: { - "1": slice(0, bas1._n_basis_input[0] * bas1.n_basis_funcs), + "1": slice(0, bas1._n_basis_input_[0] * bas1.n_basis_funcs), "2": slice( - bas1._n_basis_input[0] * bas1.n_basis_funcs, - bas1._n_basis_input[0] * bas1.n_basis_funcs - + bas2._n_basis_input[0] * bas2.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs + + bas2._n_basis_input_[0] * bas2.n_basis_funcs, ), }, ), @@ -4265,9 +4212,9 @@ def test__get_splitter( lambda bas1, bas2: { "(1 * 2)": slice( 0, - bas1._n_basis_input[0] + bas1._n_basis_input_[0] * bas1.n_basis_funcs - * bas2._n_basis_input[0] + * bas2._n_basis_input_[0] * bas2.n_basis_funcs, ) }, @@ -4292,7 +4239,7 @@ def test__get_splitter( 1, lambda bas1, bas2: { "(1 * 2)": slice( - 0, bas1._n_basis_input[0] * bas1.n_basis_funcs * bas2.n_basis_funcs + 0, bas1._n_basis_input_[0] * bas1.n_basis_funcs * bas2.n_basis_funcs ) }, ), @@ -4319,7 +4266,7 @@ def test__get_splitter( 2, lambda bas1, bas2: { "(1 * 2)": slice( - 0, bas2._n_basis_input[0] * bas1.n_basis_funcs * bas2.n_basis_funcs + 0, bas2._n_basis_input_[0] * bas1.n_basis_funcs * bas2.n_basis_funcs ) }, ), @@ -4361,7 +4308,7 @@ def test__get_splitter_split_by_input( n_input_basis_1, n_input_basis_2, compute_slice, - class_specific_params, + basis_class_specific_params, ): # skip nested if any( @@ -4373,10 +4320,17 @@ def test__get_splitter_split_by_input( n_basis = [5, 6] combine_basis = CombinedBasis() bas1_instance = combine_basis.instantiate_basis( - n_basis[0], bas1, class_specific_params, window_size=10, label="1" + n_basis[0], bas1, basis_class_specific_params, window_size=10, label="1" ) + bas1_instance.set_input_shape( + *([n_input_basis_1] * bas1_instance._n_input_dimensionality) + ) + bas2_instance = combine_basis.instantiate_basis( - n_basis[1], bas2, class_specific_params, window_size=10, label="2" + n_basis[1], bas2, basis_class_specific_params, window_size=10, label="2" + ) + bas2_instance.set_input_shape( + *([n_input_basis_2] * bas1_instance._n_input_dimensionality) ) func1 = getattr(bas1_instance, operator) @@ -4396,7 +4350,7 @@ def test__get_splitter_split_by_input( "bas1, bas2, bas3", list(itertools.product(*[list_all_basis_classes()] * 3)), ) -def test_duplicate_keys(bas1, bas2, bas3, class_specific_params): +def test_duplicate_keys(bas1, bas2, bas3, basis_class_specific_params): # skip nested if any( bas in (AdditiveBasis, MultiplicativeBasis, basis.TransformerBasis) @@ -4406,13 +4360,13 @@ def test_duplicate_keys(bas1, bas2, bas3, class_specific_params): combine_basis = CombinedBasis() bas1_instance = combine_basis.instantiate_basis( - 5, bas1, class_specific_params, window_size=10, label="label" + 5, bas1, basis_class_specific_params, window_size=10, label="label" ) bas2_instance = combine_basis.instantiate_basis( - 5, bas2, class_specific_params, window_size=10, label="label" + 5, bas2, basis_class_specific_params, window_size=10, label="label" ) bas3_instance = combine_basis.instantiate_basis( - 5, bas3, class_specific_params, window_size=10, label="label" + 5, bas3, basis_class_specific_params, window_size=10, label="label" ) bas_obj = bas1_instance + bas2_instance + bas3_instance @@ -4443,7 +4397,7 @@ def test_duplicate_keys(bas1, bas2, bas3, class_specific_params): ], ) def test_split_feature_axis( - bas1, bas2, x, axis, expectation, exp_shapes, class_specific_params + bas1, bas2, x, axis, expectation, exp_shapes, basis_class_specific_params ): # skip nested if any( @@ -4455,10 +4409,10 @@ def test_split_feature_axis( n_basis = [5, 6] combine_basis = CombinedBasis() bas1_instance = combine_basis.instantiate_basis( - n_basis[0], bas1, class_specific_params, window_size=10, label="1" + n_basis[0], bas1, basis_class_specific_params, window_size=10, label="1" ) bas2_instance = combine_basis.instantiate_basis( - n_basis[1], bas2, class_specific_params, window_size=10, label="2" + n_basis[1], bas2, basis_class_specific_params, window_size=10, label="2" ) bas = bas1_instance + bas2_instance diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 5e4ce13d..9e52a4f2 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -21,7 +21,7 @@ ) def test_sklearn_transformer_pipeline(bas, poissonGLM_model_instantiation): X, y, model, _, _ = poissonGLM_model_instantiation - bas = TransformerBasis(bas) + bas = TransformerBasis(bas.set_input_shape(*([1] * bas._n_input_dimensionality))) pipe = pipeline.Pipeline([("eval", bas), ("fit", model)]) pipe.fit(X[:, : bas._basis._n_input_dimensionality] ** 2, y) @@ -39,7 +39,7 @@ def test_sklearn_transformer_pipeline(bas, poissonGLM_model_instantiation): ) def test_sklearn_transformer_pipeline_cv(bas, poissonGLM_model_instantiation): X, y, model, _, _ = poissonGLM_model_instantiation - bas = TransformerBasis(bas) + bas = TransformerBasis(bas.set_input_shape(*([1] * bas._n_input_dimensionality))) pipe = pipeline.Pipeline([("basis", bas), ("fit", model)]) param_grid = dict(basis__n_basis_funcs=(4, 5, 10)) gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score="raise") @@ -60,7 +60,7 @@ def test_sklearn_transformer_pipeline_cv_multiprocess( bas, poissonGLM_model_instantiation ): X, y, model, _, _ = poissonGLM_model_instantiation - bas = TransformerBasis(bas) + bas = TransformerBasis(bas.set_input_shape(*([1] * bas._n_input_dimensionality))) pipe = pipeline.Pipeline([("basis", bas), ("fit", model)]) param_grid = dict(basis__n_basis_funcs=(4, 5, 10)) gridsearch = GridSearchCV( @@ -86,8 +86,15 @@ def test_sklearn_transformer_pipeline_cv_directly_over_basis( ): X, y, model, _, _ = poissonGLM_model_instantiation bas = TransformerBasis(bas_cls(5)) + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) pipe = pipeline.Pipeline([("transformerbasis", bas), ("fit", model)]) - param_grid = dict(transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20))) + param_grid = dict( + transformerbasis___basis=( + bas_cls(5).set_input_shape(*([1] * bas._n_input_dimensionality)), + bas_cls(10).set_input_shape(*([1] * bas._n_input_dimensionality)), + bas_cls(20).set_input_shape(*([1] * bas._n_input_dimensionality)), + ) + ) gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score="raise") gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y) @@ -107,6 +114,7 @@ def test_sklearn_transformer_pipeline_cv_illegal_combination( ): X, y, model, _, _ = poissonGLM_model_instantiation bas = TransformerBasis(bas_cls(5)) + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) pipe = pipeline.Pipeline([("transformerbasis", bas), ("fit", model)]) param_grid = dict( transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20)), @@ -165,6 +173,7 @@ def test_sklearn_transformer_pipeline_pynapple( ep = nap.IntervalSet(start=[0, 20.5], end=[20, X.shape[0]]) X_nap = nap.TsdFrame(t=np.arange(X.shape[0]), d=X, time_support=ep) y_nap = nap.Tsd(t=np.arange(X.shape[0]), d=y, time_support=ep) + bas = bas.set_input_shape(*([1] * bas._n_input_dimensionality)) bas = TransformerBasis(bas) # fit a pipeline & predict from pynapple pipe = pipeline.Pipeline([("eval", bas), ("fit", model)])