Skip to content

Commit

Permalink
Merge branch 'development' into document_basis
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 11, 2024
2 parents bcf72ce + 74bde2b commit 2eb55a9
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
2 changes: 2 additions & 0 deletions src/nemos/basis/_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,7 @@ def split_by_feature(
-------
dict
A dictionary where:
- **Key**: Label of the basis.
- **Value**: the array reshaped to: ``(..., n_inputs, n_basis_funcs, ...)``
"""
Expand Down Expand Up @@ -1039,6 +1040,7 @@ def split_by_feature(
-------
dict
A dictionary where:
- **Keys**: Labels of the additive basis components.
- **Values**: Sub-arrays corresponding to each component. Each sub-array has the shape:
Expand Down
27 changes: 14 additions & 13 deletions src/nemos/basis/_basis_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,17 @@ def __init__(self, bounds: Optional[Tuple[float, float]] = None):
self.bounds = bounds

def _compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor):
"""
Evaluate basis at sample points.
"""Evaluate basis at sample points.
The basis is evaluated at the locations specified in the inputs. For example,
``compute_features(np.array([0, .5]))`` would return the array:
.. code-block:: text
The basis evaluated at the samples, or :math:`b_i(*xi)`, where :math:`b_i` is a
basis element. xi[k] must be a one-dimensional array or a pynapple Tsd/TsdFrame/TsdTensor.
b_1(0) ... b_n(0)
b_1(.5) ... b_n(.5)
where ``b_i`` is the i-th basis.
Parameters
----------
Expand Down Expand Up @@ -89,20 +95,15 @@ def __init__(self, window_size: int, conv_kwargs: Optional[dict] = None):
self.conv_kwargs = {} if conv_kwargs is None else conv_kwargs

def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor):
"""
Convolve basis functions with input time series.
"""Convolve basis functions with input time series.
A bank of basis filters (created by calling fit) is convolved with the
input data. Inputs can be a NDArray, or a pynapple Tsd/TsdFrame/TsdTensor. All the dimensions
except for the sample-axis are flattened, so that the method always returns a matrix.
A bank of basis filters is convolved with the input data. All the dimensions
except for the sample-axis are flattened, so that the method always returns a
matrix.
For example, if inputs are of shape (num_samples, 2, 3), the output will be
``(num_samples, num_basis_funcs * 2 * 3)``.
The time-axis can be specified at basis initialization by setting the keyword argument ``axis``.
For example, if ``axis == 1`` your input should be of shape ``(N1, num_samples N3, ...)``, the output of
transform will be of shape ``(num_samples, num_basis_funcs * N1 * N3 *...)``.
Parameters
----------
*xi:
Expand Down

0 comments on commit 2eb55a9

Please sign in to comment.