diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 2cd7d035..f3f76d7b 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -242,7 +242,9 @@ def add_constant(x): return X @check_transform_input - def compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: + def compute_features( + self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor + ) -> FeatureMatrix: """ Apply the basis transformation to the input data. @@ -276,7 +278,9 @@ def compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> Feat return self._compute_features(*xi) @abc.abstractmethod - def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: + def _compute_features( + self, *xi: NDArray | Tsd | TsdFrame | TsdTensor + ) -> FeatureMatrix: """Convolve or evaluate the basis.""" pass @@ -901,7 +905,9 @@ def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatri return X @add_docstring("compute_features", Basis) - def compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: + def compute_features( + self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor + ) -> FeatureMatrix: r""" Examples -------- @@ -918,7 +924,9 @@ def compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> Featu """ return super().compute_features(*xi) - def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: + def _compute_features( + self, *xi: NDArray | Tsd | TsdFrame | TsdTensor + ) -> FeatureMatrix: """ Compute features for added bases and concatenate. @@ -1304,7 +1312,9 @@ def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatri ) return X - def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: + def _compute_features( + self, *xi: NDArray | Tsd | TsdFrame | TsdTensor + ) -> FeatureMatrix: """ Compute the features for the multiplied bases, and compute their outer product. @@ -1397,7 +1407,9 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: return super().evaluate_on_grid(*n_samples) @add_docstring("compute_features", Basis) - def compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: + def compute_features( + self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor + ) -> FeatureMatrix: """ Examples -------- diff --git a/src/nemos/basis/_decaying_exponential.py b/src/nemos/basis/_decaying_exponential.py index a474b819..7762487b 100644 --- a/src/nemos/basis/_decaying_exponential.py +++ b/src/nemos/basis/_decaying_exponential.py @@ -9,6 +9,7 @@ import numpy as np import scipy.linalg from numpy.typing import ArrayLike, NDArray +from pynapple import Tsd, TsdFrame, TsdTensor from ..type_casting import support_pynapple from ..typing import FeatureMatrix @@ -19,7 +20,6 @@ min_max_rescale_samples, ) -from pynapple import Tsd, TsdFrame, TsdTensor class OrthExponentialBasis(Basis, abc.ABC): """Set of 1D basis decaying exponential functions numerically orthogonalized. diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index 7070e653..4d14a1a2 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -6,6 +6,7 @@ import numpy as np from numpy.typing import ArrayLike, NDArray +from pynapple import Tsd, TsdFrame, TsdTensor from ..type_casting import support_pynapple from ..typing import FeatureMatrix @@ -16,7 +17,6 @@ min_max_rescale_samples, ) -from pynapple import Tsd, TsdFrame, TsdTensor class RaisedCosineBasisLinear(Basis, abc.ABC): """Represent linearly-spaced raised cosine basis functions. diff --git a/src/nemos/basis/_spline_basis.py b/src/nemos/basis/_spline_basis.py index b5fde73c..dda67ab1 100644 --- a/src/nemos/basis/_spline_basis.py +++ b/src/nemos/basis/_spline_basis.py @@ -7,6 +7,7 @@ import numpy as np from numpy.typing import ArrayLike, NDArray +from pynapple import Tsd, TsdFrame, TsdTensor from scipy.interpolate import splev from ..type_casting import support_pynapple @@ -18,7 +19,6 @@ min_max_rescale_samples, ) -from pynapple import Tsd, TsdFrame, TsdTensor class SplineBasis(Basis, abc.ABC): """ @@ -218,7 +218,9 @@ def __init__( @support_pynapple(conv_type="numpy") @check_transform_input @check_one_dimensional - def _evaluate(self, sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: + def _evaluate( + self, sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor + ) -> FeatureMatrix: """ Evaluate the M-spline basis functions at given sample points. @@ -335,7 +337,9 @@ def __init__( @support_pynapple(conv_type="numpy") @check_transform_input @check_one_dimensional - def _evaluate(self, sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: + def _evaluate( + self, sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor + ) -> FeatureMatrix: """ Evaluate the B-spline basis functions with given sample points. diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 17b8f80e..9caea358 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -9,7 +9,7 @@ from ..typing import FeatureMatrix from ._basis import add_docstring -from ._basis_mixin import BasisTransformerMixin, ConvBasisMixin, EvalBasisMixin +from ._basis_mixin import ConvBasisMixin, EvalBasisMixin from ._decaying_exponential import OrthExponentialBasis from ._raised_cosine_basis import RaisedCosineBasisLinear, RaisedCosineBasisLog from ._spline_basis import BSplineBasis, CyclicBSplineBasis, MSplineBasis @@ -794,9 +794,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return super().evaluate_on_grid(n_samples) -class RaisedCosineLinearEval( - EvalBasisMixin, RaisedCosineBasisLinear -): +class RaisedCosineLinearEval(EvalBasisMixin, RaisedCosineBasisLinear): """ Represent linearly-spaced raised cosine basis functions. @@ -910,9 +908,7 @@ def split_by_feature( return super().split_by_feature(x, axis=axis) -class RaisedCosineLinearConv( - ConvBasisMixin, RaisedCosineBasisLinear -): +class RaisedCosineLinearConv(ConvBasisMixin, RaisedCosineBasisLinear): """ Represent linearly-spaced raised cosine basis functions.