Skip to content

Commit

Permalink
merged conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 9, 2024
2 parents ec3c4c2 + 98c2a74 commit ef3e22b
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 15 deletions.
4 changes: 2 additions & 2 deletions docs/background/plot_01_1D_basis_function.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ if path.exists():
## Feature Computation
The bases in the `nemos.basis` module can be grouped into two categories:

1. **Evaluation Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to evaluate the basis directly, applying a non-linear transformation to the input. Classes in this category have names starting with "Eval," such as `BSplineEval`.
1. **Evaluation Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to evaluate the basis directly, applying a non-linear transformation to the input. Classes in this category have names ending with "Eval," such as `BSplineEval`.

2. **Convolution Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to convolve the input with a kernel of basis elements, using a `window_size` specified by the user. Classes in this category have names starting with "Conv," such as `BSplineConv`.
2. **Convolution Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to convolve the input with a kernel of basis elements, using a `window_size` specified by the user. Classes in this category have names ending with "Conv," such as `BSplineConv`.

Let's see how this two modalities operate.

Expand Down
2 changes: 1 addition & 1 deletion docs/background/plot_03_1D_convolution.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ if path.exists():

## Convolve using [`Basis.compute_features`](nemos.basis._basis.Basis.compute_features)

Every basis in the `nemos.basis` module whose class name starts with "Conv" will perform a 1D convolution over the
Every basis in the `nemos.basis` module whose class name ends with "Conv" will perform a 1D convolution over the
provided input when the `compute_features` method is called. The basis elements will be used as filters for the
convolution.

Expand Down
8 changes: 4 additions & 4 deletions docs/developers_notes/04-basis_module.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Abstract Class Basis

The super-class [`Basis`](nemos.basis._basis.Basis) provides two public methods, [`compute_features`](the-public-method-compute_features) and [`evaluate_on_grid`](the-public-method-evaluate_on_grid). These methods perform checks on both the input provided by the user and the output of the evaluation to ensure correctness, and are thus considered "safe". They both make use of the abstract method `_evaluate` that is specific for each concrete class. See below for more details.

## The Class `nemos.basis._basis.Basis`
## The Abstract Super-class [`Basis`](nemos.basis._basis.Basis)

(the-public-method-compute_features)=
### The Public Method `compute_features`
Expand All @@ -42,7 +42,7 @@ It accepts one or more NumPy array or pynapple `Tsd` object as input, and perfor

1. Checks that the inputs all have the same sample size `M`, and raises a `ValueError` if this is not the case.
2. Checks that the number of inputs matches what the basis being evaluated expects (e.g., one input for a 1-D basis, N inputs for an N-D basis, or the sum of N 1-D bases), and raises a `ValueError` if this is not the case.
3. In `"eval"` mode, calls the `_evaluate` method on the input, which is the subclass-specific implementation of the basis set evaluation. In `"conv"` mode, generates a filter bank using [`compute_features`](nemos.basis._basis.Basis.evaluate_on_grid) and then applies the convolution to the input with [`nemos.convolve.create_convolutional_predictor`](nemos.convolve.create_convolutional_predictor).
3. In `"eval"` mode, calls the [`_evaluate`](nemos.basis._basis.Basis._evaluate) method on the input, which is the subclass-specific implementation of the basis set evaluation. In `"conv"` mode, generates a filter bank using [`compute_features`](nemos.basis._basis.Basis.evaluate_on_grid) and then applies the convolution to the input with [`nemos.convolve.create_convolutional_predictor`](nemos.convolve.create_convolutional_predictor).
4. Returns a NumPy array or pynapple `TsdFrame` of shape `(M, n_basis_funcs)`, with each basis element evaluated at the samples.

:::{admonition} Multiple epochs
Expand All @@ -61,14 +61,14 @@ This method performs the following steps:

1. Checks that the number of inputs matches what the basis being evaluated expects (e.g., one input for a 1-D basis, N inputs for an N-D basis, or the sum of N 1-D bases), and raises a `ValueError` if this is not the case.
2. Calls `_get_samples` method, which returns equidistant samples over the domain of the basis function. The domain may depend on the type of basis.
3. Calls the `_evaluate` method.
3. Calls the [`_evaluate`](nemos.basis._basis.Basis._evaluate) method on these samples.
4. Returns both the sample grid points of shape `(m1, ..., mN)`, and the evaluation output at each grid point of shape `(m1, ..., mN, n_basis_funcs)`, where `mi` is the number of sample points for the i-th axis of the grid.

### Abstract Methods

The [`nemos.basis._basis.Basis`](nemos.basis._basis.Basis) class has the following abstract methods, which every concrete subclass must implement:

1. `_evaluate`: Evaluates a basis over some specified samples.
1. [`_evaluate`](nemos.basis._basis.Basis._evaluate) : Evaluates a basis over some specified samples.
2. `_check_n_basis_min`: Checks the minimum number of basis functions required. This requirement can be specific to the type of basis.

## Contributors Guidelines
Expand Down
11 changes: 4 additions & 7 deletions src/nemos/basis/_basis_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ def _compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor):
Returns
-------
:
A matrix with the transformed features. The basis evaluated at the samples,
or :math:`b_i(*xi)`, where :math:`b_i` is a basis element. xi[k] must be a one-dimensional array
or a pynapple Tsd.
A matrix with the transformed features.
"""
return self._evaluate(*xi)
Expand Down Expand Up @@ -144,15 +142,14 @@ def _set_kernel(self) -> "ConvBasisMixin":
-----
Subclasses implementing this method should detail the specifics of how the kernel is
computed and how the input parameters are utilized.
"""
self.kernel_ = self._evaluate(np.linspace(0, 1, self.window_size))
return self

@property
def window_size(self):
"""Window size as number of samples.
Duration of the convolutional kernel in number of samples.
"""Duration of the convolutional kernel in number of samples.
"""
return self._window_size

Expand All @@ -161,7 +158,7 @@ def window_size(self, window_size):
"""Setter for the window size parameter."""
if window_size is None:
raise ValueError(
"If the basis is in `conv` mode, you must provide a window_size!"
"You must provide a window_size!"
)

elif not (isinstance(window_size, int) and window_size > 0):
Expand Down
2 changes: 1 addition & 1 deletion src/nemos/basis/_transformer_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(self, basis: Basis):

@staticmethod
def _unpack_inputs(X: FeatureMatrix):
"""Unpack impute without using transpose.
"""Unpack inputs without using transpose.
Unpack horizontally stacked inputs using slicing. This works gracefully with ``pynapple``,
returning a list of Tsd objects. Attempt to unpack using *X.T will raise a ``pynapple``
Expand Down

0 comments on commit ef3e22b

Please sign in to comment.