Skip to content

Commit

Permalink
linted
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 5, 2024
1 parent ad45ea7 commit 24697c2
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 76 deletions.
11 changes: 6 additions & 5 deletions src/nemos/basis/_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ def _check_input_shape_consistency(self, x: NDArray):
"different shape, please create a new basis instance."
)

def set_input_shape(self, xi: int | tuple[int,...] | NDArray):
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Set the expected input shape for the basis object.
Expand Down Expand Up @@ -803,7 +803,9 @@ def set_input_shape(self, xi: int | tuple[int,...] | NDArray):
"""
if isinstance(xi, tuple):
if not all(isinstance(i, int) for i in xi):
raise ValueError(f"The tuple provided contains non integer values. Tuple: {xi}.")
raise ValueError(
f"The tuple provided contains non integer values. Tuple: {xi}."
)
shape = xi
elif isinstance(xi, int):
shape = () if xi == 1 else (xi,)
Expand Down Expand Up @@ -866,7 +868,7 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None:
self._basis2 = basis2
CompositeBasisMixin.__init__(self)

def set_input_shape(self, *xi: int | tuple[int,...] | NDArray) -> Basis:
def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis:
"""
Set the expected input shape for the basis object.
Expand Down Expand Up @@ -1237,7 +1239,6 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None:
BasisTransformerMixin.__init__(self)
CompositeBasisMixin.__init__(self)


def set_kernel(self, *xi: NDArray) -> Basis:
"""Call fit on the multiplied basis.
Expand Down Expand Up @@ -1323,7 +1324,7 @@ def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
)
return X

def set_input_shape(self, *xi: int | tuple[int,...] | NDArray) -> Basis:
def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis:
"""
Set the expected input shape for the basis object.
Expand Down
11 changes: 8 additions & 3 deletions src/nemos/basis/_basis_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import copy
import inspect
from typing import Optional, Tuple, Union, TYPE_CHECKING
from typing import TYPE_CHECKING, Optional, Tuple, Union

import numpy as np
from numpy.typing import NDArray
Expand All @@ -15,6 +15,7 @@
if TYPE_CHECKING:
from ._basis import Basis


class EvalBasisMixin:
"""Mixin class for evaluational basis."""

Expand Down Expand Up @@ -293,5 +294,9 @@ def set_kernel(self, *xi: NDArray) -> Basis:

def _check_input_shape_consistency(self, *xi: NDArray):
"""Check the input shape consistency for all basis elements."""
self._basis1._check_input_shape_consistency(*xi[: self._basis1._n_input_dimensionality])
self._basis2._check_input_shape_consistency(*xi[self._basis1._n_input_dimensionality:])
self._basis1._check_input_shape_consistency(
*xi[: self._basis1._n_input_dimensionality]
)
self._basis2._check_input_shape_consistency(
*xi[self._basis1._n_input_dimensionality :]
)
27 changes: 14 additions & 13 deletions src/nemos/basis/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
return super().evaluate_on_grid(n_samples)

@add_docstring("set_input_shape", BSplineBasis)
def set_input_shape(self, xi: int | tuple[int,...] | NDArray):
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
--------
Expand Down Expand Up @@ -186,7 +186,6 @@ def set_input_shape(self, xi: int | tuple[int,...] | NDArray):
return super().set_input_shape(xi)



class BSplineConv(ConvBasisMixin, BSplineBasis):
"""
B-spline 1-dimensional basis functions.
Expand Down Expand Up @@ -307,7 +306,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
return super().evaluate_on_grid(n_samples)

@add_docstring("set_input_shape", BSplineBasis)
def set_input_shape(self, xi: int | tuple[int,...] | NDArray):
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
--------
Expand Down Expand Up @@ -448,7 +447,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
return super().evaluate_on_grid(n_samples)

@add_docstring("set_input_shape", CyclicBSplineBasis)
def set_input_shape(self, xi: int | tuple[int,...] | NDArray):
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
--------
Expand Down Expand Up @@ -588,7 +587,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
return super().evaluate_on_grid(n_samples)

@add_docstring("set_input_shape", CyclicBSplineBasis)
def set_input_shape(self, xi: int | tuple[int,...] | NDArray):
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
--------
Expand Down Expand Up @@ -753,7 +752,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
return super().evaluate_on_grid(n_samples)

@add_docstring("set_input_shape", MSplineBasis)
def set_input_shape(self, xi: int | tuple[int,...] | NDArray):
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
--------
Expand Down Expand Up @@ -917,7 +916,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
return super().evaluate_on_grid(n_samples)

@add_docstring("set_input_shape", MSplineBasis)
def set_input_shape(self, xi: int | tuple[int,...] | NDArray):
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
--------
Expand Down Expand Up @@ -1061,7 +1060,7 @@ def split_by_feature(
return super().split_by_feature(x, axis=axis)

@add_docstring("set_input_shape", RaisedCosineBasisLinear)
def set_input_shape(self, xi: int | tuple[int,...] | NDArray):
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
--------
Expand Down Expand Up @@ -1204,7 +1203,7 @@ def split_by_feature(
return super().split_by_feature(x, axis=axis)

@add_docstring("set_input_shape", RaisedCosineBasisLinear)
def set_input_shape(self, xi: int | tuple[int,...] | NDArray):
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
--------
Expand Down Expand Up @@ -1356,7 +1355,7 @@ def split_by_feature(
return super().split_by_feature(x, axis=axis)

@add_docstring("set_input_shape", RaisedCosineBasisLog)
def set_input_shape(self, xi: int | tuple[int,...] | NDArray):
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
--------
Expand Down Expand Up @@ -1509,7 +1508,7 @@ def split_by_feature(
return super().split_by_feature(x, axis=axis)

@add_docstring("set_input_shape", RaisedCosineBasisLog)
def set_input_shape(self, xi: int | tuple[int,...] | NDArray):
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
--------
Expand All @@ -1536,6 +1535,7 @@ def set_input_shape(self, xi: int | tuple[int,...] | NDArray):
"""
return super().set_input_shape(xi)


class OrthExponentialEval(EvalBasisMixin, OrthExponentialBasis):
"""Set of 1D basis decaying exponential functions numerically orthogonalized.
Expand Down Expand Up @@ -1648,7 +1648,7 @@ def split_by_feature(
return super().split_by_feature(x, axis=axis)

@add_docstring("set_input_shape", OrthExponentialBasis)
def set_input_shape(self, xi: int | tuple[int,...] | NDArray):
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
--------
Expand All @@ -1675,6 +1675,7 @@ def set_input_shape(self, xi: int | tuple[int,...] | NDArray):
"""
return super().set_input_shape(xi)


class OrthExponentialConv(ConvBasisMixin, OrthExponentialBasis):
"""Set of 1D basis decaying exponential functions numerically orthogonalized.
Expand Down Expand Up @@ -1785,7 +1786,7 @@ def split_by_feature(
return super().split_by_feature(x, axis=axis)

@add_docstring("set_input_shape", OrthExponentialBasis)
def set_input_shape(self, xi: int | tuple[int,...] | NDArray):
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
--------
Expand Down
Loading

0 comments on commit 24697c2

Please sign in to comment.