From 9d6fbae248ab921bfe30e71e10be1dbadb1d9d90 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Sat, 14 Dec 2024 11:05:54 -0500 Subject: [PATCH] fixed rendering merge completed --- src/nemos/basis/_decaying_exponential.py | 3 +++ src/nemos/basis/_raised_cosine_basis.py | 5 +++++ src/nemos/basis/_spline_basis.py | 16 ++++++++++++++++ 3 files changed, 24 insertions(+) diff --git a/src/nemos/basis/_decaying_exponential.py b/src/nemos/basis/_decaying_exponential.py index a1fd4a24..5f80df58 100644 --- a/src/nemos/basis/_decaying_exponential.py +++ b/src/nemos/basis/_decaying_exponential.py @@ -22,6 +22,8 @@ class OrthExponentialBasis(Basis, AtomicBasisMixin, abc.ABC): Parameters ---------- + n_basis_funcs + Number of basis functions. decay_rates : Decay rates of the exponentials, shape ``(n_basis_funcs,)``. mode : @@ -34,6 +36,7 @@ class OrthExponentialBasis(Basis, AtomicBasisMixin, abc.ABC): def __init__( self, + n_basis_funcs: int, decay_rates: NDArray[np.floating], mode="eval", label: Optional[str] = "OrthExponentialBasis", diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index c964f1dc..dbf039eb 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -22,6 +22,8 @@ class RaisedCosineBasisLinear(Basis, AtomicBasisMixin, abc.ABC): Parameters ---------- + n_basis_funcs : + The number of basis functions. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. @@ -41,6 +43,7 @@ class RaisedCosineBasisLinear(Basis, AtomicBasisMixin, abc.ABC): def __init__( self, + n_basis_funcs: int, mode="eval", width: float = 2.0, label: Optional[str] = "RaisedCosineBasisLinear", @@ -232,6 +235,7 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear, abc.ABC): def __init__( self, + n_basis_funcs: int, mode="eval", width: float = 2.0, time_scaling: float = None, @@ -239,6 +243,7 @@ def __init__( label: Optional[str] = "RaisedCosineBasisLog", ) -> None: super().__init__( + n_basis_funcs, mode=mode, width=width, label=label, diff --git a/src/nemos/basis/_spline_basis.py b/src/nemos/basis/_spline_basis.py index 78cc34a6..c8f42d90 100644 --- a/src/nemos/basis/_spline_basis.py +++ b/src/nemos/basis/_spline_basis.py @@ -22,6 +22,8 @@ class SplineBasis(Basis, AtomicBasisMixin, abc.ABC): Parameters ---------- + n_basis_funcs : + Number of basis functions. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. @@ -39,6 +41,7 @@ class SplineBasis(Basis, AtomicBasisMixin, abc.ABC): def __init__( self, + n_basis_funcs: int, order: int = 2, label: Optional[str] = None, mode: Literal["conv", "eval"] = "eval", @@ -156,6 +159,9 @@ class MSplineBasis(SplineBasis, abc.ABC): Parameters ---------- + n_basis_funcs : + The number of basis functions to generate. More basis functions allow for + more flexible data modeling but can lead to overfitting. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. @@ -193,11 +199,13 @@ class MSplineBasis(SplineBasis, abc.ABC): def __init__( self, + n_basis_funcs: int, mode: Literal["eval", "conv"] = "eval", order: int = 2, label: Optional[str] = "MSplineEval", ) -> None: super().__init__( + n_basis_funcs, mode=mode, order=order, label=label, @@ -294,6 +302,8 @@ class BSplineBasis(SplineBasis, abc.ABC): Parameters ---------- + n_basis_funcs : + Number of basis functions. mode : The mode of operation. ``'eval'`` for evaluation at sample points, 'conv' for convolutional operation. @@ -319,11 +329,13 @@ class BSplineBasis(SplineBasis, abc.ABC): def __init__( self, + n_basis_funcs: int, mode="eval", order: int = 4, label: Optional[str] = "BSplineBasis", ): super().__init__( + n_basis_funcs, mode=mode, order=order, label=label, @@ -408,6 +420,8 @@ class CyclicBSplineBasis(SplineBasis, abc.ABC): Parameters ---------- + n_basis_funcs : + Number of basis functions. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. @@ -429,11 +443,13 @@ class CyclicBSplineBasis(SplineBasis, abc.ABC): def __init__( self, + n_basis_funcs: int, mode="eval", order: int = 4, label: Optional[str] = "CyclicBSplineBasis", ): super().__init__( + n_basis_funcs, mode=mode, order=order, label=label,