Skip to content

Commit

Permalink
linted
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 9, 2024
1 parent 4d4c70f commit ec3c4c2
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 18 deletions.
24 changes: 18 additions & 6 deletions src/nemos/basis/_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,9 @@ def add_constant(x):
return X

@check_transform_input
def compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
def compute_features(
self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor
) -> FeatureMatrix:
"""
Apply the basis transformation to the input data.
Expand Down Expand Up @@ -276,7 +278,9 @@ def compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> Feat
return self._compute_features(*xi)

@abc.abstractmethod
def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
def _compute_features(
self, *xi: NDArray | Tsd | TsdFrame | TsdTensor
) -> FeatureMatrix:
"""Convolve or evaluate the basis."""
pass

Expand Down Expand Up @@ -901,7 +905,9 @@ def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatri
return X

@add_docstring("compute_features", Basis)
def compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
def compute_features(
self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor
) -> FeatureMatrix:
r"""
Examples
--------
Expand All @@ -918,7 +924,9 @@ def compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> Featu
"""
return super().compute_features(*xi)

def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
def _compute_features(
self, *xi: NDArray | Tsd | TsdFrame | TsdTensor
) -> FeatureMatrix:
"""
Compute features for added bases and concatenate.
Expand Down Expand Up @@ -1304,7 +1312,9 @@ def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatri
)
return X

def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
def _compute_features(
self, *xi: NDArray | Tsd | TsdFrame | TsdTensor
) -> FeatureMatrix:
"""
Compute the features for the multiplied bases, and compute their outer product.
Expand Down Expand Up @@ -1397,7 +1407,9 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]:
return super().evaluate_on_grid(*n_samples)

@add_docstring("compute_features", Basis)
def compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
def compute_features(
self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor
) -> FeatureMatrix:
"""
Examples
--------
Expand Down
2 changes: 1 addition & 1 deletion src/nemos/basis/_decaying_exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np
import scipy.linalg
from numpy.typing import ArrayLike, NDArray
from pynapple import Tsd, TsdFrame, TsdTensor

from ..type_casting import support_pynapple
from ..typing import FeatureMatrix
Expand All @@ -19,7 +20,6 @@
min_max_rescale_samples,
)

from pynapple import Tsd, TsdFrame, TsdTensor

class OrthExponentialBasis(Basis, abc.ABC):
"""Set of 1D basis decaying exponential functions numerically orthogonalized.
Expand Down
2 changes: 1 addition & 1 deletion src/nemos/basis/_raised_cosine_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
from numpy.typing import ArrayLike, NDArray
from pynapple import Tsd, TsdFrame, TsdTensor

from ..type_casting import support_pynapple
from ..typing import FeatureMatrix
Expand All @@ -16,7 +17,6 @@
min_max_rescale_samples,
)

from pynapple import Tsd, TsdFrame, TsdTensor

class RaisedCosineBasisLinear(Basis, abc.ABC):
"""Represent linearly-spaced raised cosine basis functions.
Expand Down
10 changes: 7 additions & 3 deletions src/nemos/basis/_spline_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np
from numpy.typing import ArrayLike, NDArray
from pynapple import Tsd, TsdFrame, TsdTensor
from scipy.interpolate import splev

from ..type_casting import support_pynapple
Expand All @@ -18,7 +19,6 @@
min_max_rescale_samples,
)

from pynapple import Tsd, TsdFrame, TsdTensor

class SplineBasis(Basis, abc.ABC):
"""
Expand Down Expand Up @@ -218,7 +218,9 @@ def __init__(
@support_pynapple(conv_type="numpy")
@check_transform_input
@check_one_dimensional
def _evaluate(self, sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
def _evaluate(
self, sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor
) -> FeatureMatrix:
"""
Evaluate the M-spline basis functions at given sample points.
Expand Down Expand Up @@ -335,7 +337,9 @@ def __init__(
@support_pynapple(conv_type="numpy")
@check_transform_input
@check_one_dimensional
def _evaluate(self, sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
def _evaluate(
self, sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor
) -> FeatureMatrix:
"""
Evaluate the B-spline basis functions with given sample points.
Expand Down
10 changes: 3 additions & 7 deletions src/nemos/basis/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ..typing import FeatureMatrix
from ._basis import add_docstring
from ._basis_mixin import BasisTransformerMixin, ConvBasisMixin, EvalBasisMixin
from ._basis_mixin import ConvBasisMixin, EvalBasisMixin
from ._decaying_exponential import OrthExponentialBasis
from ._raised_cosine_basis import RaisedCosineBasisLinear, RaisedCosineBasisLog
from ._spline_basis import BSplineBasis, CyclicBSplineBasis, MSplineBasis
Expand Down Expand Up @@ -794,9 +794,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
return super().evaluate_on_grid(n_samples)


class RaisedCosineLinearEval(
EvalBasisMixin, RaisedCosineBasisLinear
):
class RaisedCosineLinearEval(EvalBasisMixin, RaisedCosineBasisLinear):
"""
Represent linearly-spaced raised cosine basis functions.
Expand Down Expand Up @@ -910,9 +908,7 @@ def split_by_feature(
return super().split_by_feature(x, axis=axis)


class RaisedCosineLinearConv(
ConvBasisMixin, RaisedCosineBasisLinear
):
class RaisedCosineLinearConv(ConvBasisMixin, RaisedCosineBasisLinear):
"""
Represent linearly-spaced raised cosine basis functions.
Expand Down

0 comments on commit ec3c4c2

Please sign in to comment.