Skip to content

Commit

Permalink
fixed inheritance and removed TransformerMixin init call
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 9, 2024
1 parent 2b05994 commit 4d4c70f
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 56 deletions.
122 changes: 81 additions & 41 deletions src/nemos/basis/_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import jax
import numpy as np
from numpy.typing import ArrayLike, NDArray
from pynapple import Tsd, TsdFrame
from pynapple import Tsd, TsdFrame, TsdTensor

from ..base_class import Base
from ..type_casting import support_pynapple
Expand Down Expand Up @@ -242,21 +242,21 @@ def add_constant(x):
return X

@check_transform_input
def compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
def compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
"""
Apply the basis transformation to the input data.
This method is designed to be a high-level interface for transforming input
data using the basis functions defined by the subclass. Depending on the basis'
mode ('eval' or 'conv'), it either evaluates the basis functions at the sample
mode ('Eval' or 'Conv'), it either evaluates the basis functions at the sample
points or performs a convolution operation between the input data and the
basis functions.
Parameters
----------
*xi :
Input data arrays to be transformed. The shape and content requirements
depend on the subclass and mode of operation ('eval' or 'conv').
depend on the subclass and mode of operation ('Eval' or 'Conv').
Returns
-------
Expand All @@ -276,7 +276,7 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
return self._compute_features(*xi)

@abc.abstractmethod
def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
"""Convolve or evaluate the basis."""
pass

Expand All @@ -286,7 +286,7 @@ def _set_kernel(self):
pass

@abc.abstractmethod
def _evaluate(self, *xi: ArrayLike) -> FeatureMatrix:
def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
"""
Abstract method to evaluate the basis functions at given points.
Expand Down Expand Up @@ -579,37 +579,14 @@ def _get_feature_slicing(
n_inputs = n_inputs or self._n_basis_input
start_slice = start_slice or 0

# If the instance is of AdditiveBasis type, handle slicing for the additive components
if isinstance(self, AdditiveBasis):
split_dict, start_slice = self._basis1._get_feature_slicing(
n_inputs[: len(self._basis1._n_basis_input)],
start_slice,
split_by_input=split_by_input,
)
sp2, start_slice = self._basis2._get_feature_slicing(
n_inputs[len(self._basis1._n_basis_input) :],
start_slice,
split_by_input=split_by_input,
)
split_dict = self._merge_slicing_dicts(split_dict, sp2)
else:
# Handle the default case for other basis types
split_dict, start_slice = self._get_default_slicing(
split_by_input, start_slice
)
# Handle the default case for non-additive basis types
# See overwritten method for recursion logic
split_dict, start_slice = self._get_default_slicing(
split_by_input=split_by_input, start_slice=start_slice
)

return split_dict, start_slice

def _merge_slicing_dicts(self, dict1: dict, dict2: dict) -> dict:
"""Merge two slicing dictionaries, handling key conflicts."""
for key, val in dict2.items():
if key in dict1:
new_key = self._generate_unique_key(dict1, key)
dict1[new_key] = val
else:
dict1[key] = val
return dict1

@staticmethod
def _generate_unique_key(existing_dict: dict, key: str) -> str:
"""Generate a unique key if there is a conflict."""
Expand Down Expand Up @@ -884,7 +861,7 @@ def _check_n_basis_min(self) -> None:
@support_pynapple(conv_type="numpy")
@check_transform_input
@check_one_dimensional
def _evaluate(self, *xi: ArrayLike) -> FeatureMatrix:
def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
"""
Evaluate the basis at the input samples.
Expand Down Expand Up @@ -924,7 +901,7 @@ def _evaluate(self, *xi: ArrayLike) -> FeatureMatrix:
return X

@add_docstring("compute_features", Basis)
def compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
def compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
r"""
Examples
--------
Expand All @@ -941,7 +918,7 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
"""
return super().compute_features(*xi)

def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
"""
Compute features for added bases and concatenate.
Expand Down Expand Up @@ -1159,6 +1136,70 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]:
"""
return super().evaluate_on_grid(*n_samples)

def _get_feature_slicing(
self,
n_inputs: Optional[tuple] = None,
start_slice: Optional[int] = None,
split_by_input: bool = True,
) -> Tuple[dict, int]:
"""
Calculate and return the slicing for features based on the input structure.
This method determines how to slice the features for different basis types.
Parameters
----------
n_inputs :
The number of input basis for each component, by default it uses ``self._n_basis_input``.
start_slice :
The starting index for slicing, by default it starts from 0.
split_by_input :
Flag indicating whether to split the slicing by individual inputs or not.
If ``False``, a single slice is generated for all inputs.
Returns
-------
split_dict :
Dictionary with keys as labels and values as slices representing
the slicing for each input or additive component, if split_by_input equals to
True or False respectively.
start_slice :
The updated starting index after slicing.
See Also
--------
_get_default_slicing : Handles default slicing logic.
_merge_slicing_dicts : Merges multiple slicing dictionaries, handling keys conflicts.
"""
# Set default values for n_inputs and start_slice if not provided
n_inputs = n_inputs or self._n_basis_input
start_slice = start_slice or 0

# If the instance is of AdditiveBasis type, handle slicing for the additive components

split_dict, start_slice = self._basis1._get_feature_slicing(
n_inputs[: len(self._basis1._n_basis_input)],
start_slice,
split_by_input=split_by_input,
)
sp2, start_slice = self._basis2._get_feature_slicing(
n_inputs[len(self._basis1._n_basis_input) :],
start_slice,
split_by_input=split_by_input,
)
split_dict = self._merge_slicing_dicts(split_dict, sp2)
return split_dict, start_slice

def _merge_slicing_dicts(self, dict1: dict, dict2: dict) -> dict:
"""Merge two slicing dictionaries, handling key conflicts."""
for key, val in dict2.items():
if key in dict1:
new_key = self._generate_unique_key(dict1, key)
dict1[new_key] = val
else:
dict1[key] = val
return dict1


class MultiplicativeBasis(Basis):
"""
Expand Down Expand Up @@ -1205,7 +1246,6 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None:
self._label = "(" + basis1.label + " * " + basis2.label + ")"
self._basis1 = basis1
self._basis2 = basis2
BasisTransformerMixin.__init__(self)

def _check_n_basis_min(self) -> None:
pass
Expand All @@ -1232,7 +1272,7 @@ def _set_kernel(self, *xi: NDArray) -> Basis:
@support_pynapple(conv_type="numpy")
@check_transform_input
@check_one_dimensional
def _evaluate(self, *xi: ArrayLike) -> FeatureMatrix:
def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
"""
Evaluate the basis at the input samples.
Expand Down Expand Up @@ -1264,7 +1304,7 @@ def _evaluate(self, *xi: ArrayLike) -> FeatureMatrix:
)
return X

def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
"""
Compute the features for the multiplied bases, and compute their outer product.
Expand Down Expand Up @@ -1357,7 +1397,7 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]:
return super().evaluate_on_grid(*n_samples)

@add_docstring("compute_features", Basis)
def compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
def compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
"""
Examples
--------
Expand Down
11 changes: 6 additions & 5 deletions src/nemos/basis/_basis_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from typing import Optional, Tuple, Union

import numpy as np
from numpy.typing import ArrayLike
from numpy.typing import ArrayLike, NDArray
from pynapple import Tsd, TsdFrame, TsdTensor

from ..convolve import create_convolutional_predictor
from ._transformer_basis import TransformerBasis
Expand All @@ -19,12 +20,12 @@ class EvalBasisMixin:
def __init__(self, bounds: Optional[Tuple[float, float]] = None):
self.bounds = bounds

def _compute_features(self, *xi: ArrayLike):
def _compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor):
"""
Apply the basis transformation to the input data.
Evaluate basis at sample points.
The basis evaluated at the samples, or :math:`b_i(*xi)`, where :math:`b_i` is a
basis element. xi[k] must be a one-dimensional array or a pynapple Tsd.
basis element. xi[k] must be a one-dimensional array or a pynapple Tsd/TsdFrame/TsdTensor.
Parameters
----------
Expand Down Expand Up @@ -89,7 +90,7 @@ def __init__(self, window_size: int, conv_kwargs: Optional[dict] = None):
self.window_size = window_size
self.conv_kwargs = {} if conv_kwargs is None else conv_kwargs

def _compute_features(self, *xi: ArrayLike):
def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor):
"""
Convolve basis functions with input time series.
Expand Down
5 changes: 3 additions & 2 deletions src/nemos/basis/_decaying_exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numpy as np
import scipy.linalg
from numpy.typing import NDArray
from numpy.typing import ArrayLike, NDArray

from ..type_casting import support_pynapple
from ..typing import FeatureMatrix
Expand All @@ -19,6 +19,7 @@
min_max_rescale_samples,
)

from pynapple import Tsd, TsdFrame, TsdTensor

class OrthExponentialBasis(Basis, abc.ABC):
"""Set of 1D basis decaying exponential functions numerically orthogonalized.
Expand Down Expand Up @@ -134,7 +135,7 @@ def _check_sample_size(self, *sample_pts: NDArray) -> None:
@check_one_dimensional
def _evaluate(
self,
sample_pts: NDArray,
sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor,
) -> FeatureMatrix:
"""Generate basis functions with given spacing.
Expand Down
5 changes: 3 additions & 2 deletions src/nemos/basis/_raised_cosine_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
min_max_rescale_samples,
)

from pynapple import Tsd, TsdFrame, TsdTensor

class RaisedCosineBasisLinear(Basis, abc.ABC):
"""Represent linearly-spaced raised cosine basis functions.
Expand Down Expand Up @@ -101,7 +102,7 @@ def _check_width(width: float) -> None:
@check_one_dimensional
def _evaluate( # call these _evaluate
self,
sample_pts: ArrayLike,
sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor,
) -> FeatureMatrix:
"""Generate basis functions with given samples.
Expand Down Expand Up @@ -330,7 +331,7 @@ def _compute_peaks(self) -> NDArray:
@check_one_dimensional
def _evaluate(
self,
sample_pts: ArrayLike,
sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor,
) -> FeatureMatrix:
"""Generate log-spaced raised cosine basis with given samples.
Expand Down
7 changes: 4 additions & 3 deletions src/nemos/basis/_spline_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
min_max_rescale_samples,
)

from pynapple import Tsd, TsdFrame, TsdTensor

class SplineBasis(Basis, abc.ABC):
"""
Expand Down Expand Up @@ -217,7 +218,7 @@ def __init__(
@support_pynapple(conv_type="numpy")
@check_transform_input
@check_one_dimensional
def _evaluate(self, sample_pts: ArrayLike) -> FeatureMatrix:
def _evaluate(self, sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
"""
Evaluate the M-spline basis functions at given sample points.
Expand Down Expand Up @@ -334,7 +335,7 @@ def __init__(
@support_pynapple(conv_type="numpy")
@check_transform_input
@check_one_dimensional
def _evaluate(self, sample_pts: ArrayLike) -> FeatureMatrix:
def _evaluate(self, sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
"""
Evaluate the B-spline basis functions with given sample points.
Expand Down Expand Up @@ -445,7 +446,7 @@ def __init__(
@check_one_dimensional
def _evaluate(
self,
sample_pts: ArrayLike,
sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor,
) -> FeatureMatrix:
"""Evaluate the Cyclic B-spline basis functions with given sample points.
Expand Down
4 changes: 2 additions & 2 deletions src/nemos/basis/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:


class RaisedCosineLinearEval(
EvalBasisMixin, RaisedCosineBasisLinear, BasisTransformerMixin
EvalBasisMixin, RaisedCosineBasisLinear
):
"""
Represent linearly-spaced raised cosine basis functions.
Expand Down Expand Up @@ -911,7 +911,7 @@ def split_by_feature(


class RaisedCosineLinearConv(
ConvBasisMixin, RaisedCosineBasisLinear, BasisTransformerMixin
ConvBasisMixin, RaisedCosineBasisLinear
):
"""
Represent linearly-spaced raised cosine basis functions.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def test_all_basis_are_tested() -> None:
("evaluate_on_grid", "The number of points in the uniformly spaced grid"),
(
"compute_features",
"Apply the basis transformation to the input data|Convolve basis functions with input time series",
"Apply the basis transformation to the input data|Convolve basis functions with input time series|Evaluate basis at sample points",
),
(
"split_by_feature",
Expand Down

0 comments on commit 4d4c70f

Please sign in to comment.