Skip to content

Commit

Permalink
fixed rendering merge completed
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 14, 2024
1 parent 4c4758e commit 9d6fbae
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/nemos/basis/_decaying_exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class OrthExponentialBasis(Basis, AtomicBasisMixin, abc.ABC):
Parameters
----------
n_basis_funcs
Number of basis functions.
decay_rates :
Decay rates of the exponentials, shape ``(n_basis_funcs,)``.
mode :
Expand All @@ -34,6 +36,7 @@ class OrthExponentialBasis(Basis, AtomicBasisMixin, abc.ABC):

def __init__(
self,
n_basis_funcs: int,
decay_rates: NDArray[np.floating],
mode="eval",
label: Optional[str] = "OrthExponentialBasis",
Expand Down
5 changes: 5 additions & 0 deletions src/nemos/basis/_raised_cosine_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class RaisedCosineBasisLinear(Basis, AtomicBasisMixin, abc.ABC):
Parameters
----------
n_basis_funcs :
The number of basis functions.
mode :
The mode of operation. 'eval' for evaluation at sample points,
'conv' for convolutional operation.
Expand All @@ -41,6 +43,7 @@ class RaisedCosineBasisLinear(Basis, AtomicBasisMixin, abc.ABC):

def __init__(
self,
n_basis_funcs: int,
mode="eval",
width: float = 2.0,
label: Optional[str] = "RaisedCosineBasisLinear",
Expand Down Expand Up @@ -232,13 +235,15 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear, abc.ABC):

def __init__(
self,
n_basis_funcs: int,
mode="eval",
width: float = 2.0,
time_scaling: float = None,
enforce_decay_to_zero: bool = True,
label: Optional[str] = "RaisedCosineBasisLog",
) -> None:
super().__init__(
n_basis_funcs,
mode=mode,
width=width,
label=label,
Expand Down
16 changes: 16 additions & 0 deletions src/nemos/basis/_spline_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class SplineBasis(Basis, AtomicBasisMixin, abc.ABC):
Parameters
----------
n_basis_funcs :
Number of basis functions.
mode :
The mode of operation. 'eval' for evaluation at sample points,
'conv' for convolutional operation.
Expand All @@ -39,6 +41,7 @@ class SplineBasis(Basis, AtomicBasisMixin, abc.ABC):

def __init__(
self,
n_basis_funcs: int,
order: int = 2,
label: Optional[str] = None,
mode: Literal["conv", "eval"] = "eval",
Expand Down Expand Up @@ -156,6 +159,9 @@ class MSplineBasis(SplineBasis, abc.ABC):
Parameters
----------
n_basis_funcs :
The number of basis functions to generate. More basis functions allow for
more flexible data modeling but can lead to overfitting.
mode :
The mode of operation. 'eval' for evaluation at sample points,
'conv' for convolutional operation.
Expand Down Expand Up @@ -193,11 +199,13 @@ class MSplineBasis(SplineBasis, abc.ABC):

def __init__(
self,
n_basis_funcs: int,
mode: Literal["eval", "conv"] = "eval",
order: int = 2,
label: Optional[str] = "MSplineEval",
) -> None:
super().__init__(
n_basis_funcs,
mode=mode,
order=order,
label=label,
Expand Down Expand Up @@ -294,6 +302,8 @@ class BSplineBasis(SplineBasis, abc.ABC):
Parameters
----------
n_basis_funcs :
Number of basis functions.
mode :
The mode of operation. ``'eval'`` for evaluation at sample points,
'conv' for convolutional operation.
Expand All @@ -319,11 +329,13 @@ class BSplineBasis(SplineBasis, abc.ABC):

def __init__(
self,
n_basis_funcs: int,
mode="eval",
order: int = 4,
label: Optional[str] = "BSplineBasis",
):
super().__init__(
n_basis_funcs,
mode=mode,
order=order,
label=label,
Expand Down Expand Up @@ -408,6 +420,8 @@ class CyclicBSplineBasis(SplineBasis, abc.ABC):
Parameters
----------
n_basis_funcs :
Number of basis functions.
mode :
The mode of operation. 'eval' for evaluation at sample points,
'conv' for convolutional operation.
Expand All @@ -429,11 +443,13 @@ class CyclicBSplineBasis(SplineBasis, abc.ABC):

def __init__(
self,
n_basis_funcs: int,
mode="eval",
order: int = 4,
label: Optional[str] = "CyclicBSplineBasis",
):
super().__init__(
n_basis_funcs,
mode=mode,
order=order,
label=label,
Expand Down

0 comments on commit 9d6fbae

Please sign in to comment.