diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index f3f76d7b..07403e06 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -683,6 +683,7 @@ def split_by_feature( ------- dict A dictionary where: + - **Key**: Label of the basis. - **Value**: the array reshaped to: ``(..., n_inputs, n_basis_funcs, ...)`` """ @@ -1039,6 +1040,7 @@ def split_by_feature( ------- dict A dictionary where: + - **Keys**: Labels of the additive basis components. - **Values**: Sub-arrays corresponding to each component. Each sub-array has the shape: diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 00bfbcfc..714a9141 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -21,11 +21,17 @@ def __init__(self, bounds: Optional[Tuple[float, float]] = None): self.bounds = bounds def _compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor): - """ - Evaluate basis at sample points. + """Evaluate basis at sample points. + + The basis is evaluated at the locations specified in the inputs. For example, + ``compute_features(np.array([0, .5]))`` would return the array: + + .. code-block:: text - The basis evaluated at the samples, or :math:`b_i(*xi)`, where :math:`b_i` is a - basis element. xi[k] must be a one-dimensional array or a pynapple Tsd/TsdFrame/TsdTensor. + b_1(0) ... b_n(0) + b_1(.5) ... b_n(.5) + + where ``b_i`` is the i-th basis. Parameters ---------- @@ -89,20 +95,15 @@ def __init__(self, window_size: int, conv_kwargs: Optional[dict] = None): self.conv_kwargs = {} if conv_kwargs is None else conv_kwargs def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor): - """ - Convolve basis functions with input time series. + """Convolve basis functions with input time series. - A bank of basis filters (created by calling fit) is convolved with the - input data. Inputs can be a NDArray, or a pynapple Tsd/TsdFrame/TsdTensor. All the dimensions - except for the sample-axis are flattened, so that the method always returns a matrix. + A bank of basis filters is convolved with the input data. All the dimensions + except for the sample-axis are flattened, so that the method always returns a + matrix. For example, if inputs are of shape (num_samples, 2, 3), the output will be ``(num_samples, num_basis_funcs * 2 * 3)``. - The time-axis can be specified at basis initialization by setting the keyword argument ``axis``. - For example, if ``axis == 1`` your input should be of shape ``(N1, num_samples N3, ...)``, the output of - transform will be of shape ``(num_samples, num_basis_funcs * N1 * N3 *...)``. - Parameters ---------- *xi: