Skip to content

Commit

Permalink
merged basis PR1
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 10, 2024
2 parents 57e2bd2 + 6a6da11 commit bcf72ce
Show file tree
Hide file tree
Showing 13 changed files with 192 additions and 95 deletions.
2 changes: 1 addition & 1 deletion docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ These classes are the building blocks for the concrete basis classes.
AdditiveBasis
MultiplicativeBasis

**Basis As `scikit-learn` Tranformers:**
**Basis As ``scikit-learn`` Tranformers:**

.. currentmodule:: nemos.basis._transformer_basis

Expand Down
1 change: 1 addition & 0 deletions docs/background/basis/plot_01_1D_basis_function.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ We can be group the bases into two categories depending on the type of transform

2. **Convolution Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to convolve the input with a kernel of basis elements, using a `window_size` specified by the user. Classes in this category have names ending with "Conv", such as `BSplineConv`.


Let's see how this two modalities operate.

```{code-cell} ipython3
Expand Down
2 changes: 1 addition & 1 deletion docs/background/plot_03_1D_convolution.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ if path.exists():

## Convolve using [`Basis.compute_features`](nemos.basis._basis.Basis.compute_features)

Every basis in the `nemos.basis` module whose class name starts with "Conv" will perform a 1D convolution over the
Every basis in the `nemos.basis` module whose class name ends with "Conv" will perform a 1D convolution over the
provided input when the `compute_features` method is called. The basis elements will be used as filters for the
convolution.

Expand Down
6 changes: 3 additions & 3 deletions docs/developers_notes/04-basis_module.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Abstract Class Basis

The super-class [`Basis`](nemos.basis._basis.Basis) provides two public methods, [`compute_features`](the-public-method-compute_features) and [`evaluate_on_grid`](the-public-method-evaluate_on_grid). These methods perform checks on both the input provided by the user and the output of the evaluation to ensure correctness, and are thus considered "safe". They both make use of the abstract method `_evaluate` that is specific for each concrete class. See below for more details.

## The Class `nemos.basis._basis.Basis`
## The Abstract Super-class [`Basis`](nemos.basis._basis.Basis)

(the-public-method-compute_features)=
### The Public Method `compute_features`
Expand Down Expand Up @@ -61,14 +61,14 @@ This method performs the following steps:

1. Checks that the number of inputs matches what the basis being evaluated expects (e.g., one input for a 1-D basis, N inputs for an N-D basis, or the sum of N 1-D bases), and raises a `ValueError` if this is not the case.
2. Calls `_get_samples` method, which returns equidistant samples over the domain of the basis function. The domain may depend on the type of basis.
3. Calls the `_evaluate` method.
3. Calls the `_evaluate` method on these samples.
4. Returns both the sample grid points of shape `(m1, ..., mN)`, and the evaluation output at each grid point of shape `(m1, ..., mN, n_basis_funcs)`, where `mi` is the number of sample points for the i-th axis of the grid.

### Abstract Methods

The [`nemos.basis._basis.Basis`](nemos.basis._basis.Basis) class has the following abstract methods, which every concrete subclass must implement:

1. `_evaluate`: Evaluates a basis over some specified samples.
1. `_evaluate` : Evaluates a basis over some specified samples.
2. `_check_n_basis_min`: Checks the minimum number of basis functions required. This requirement can be specific to the type of basis.

## Contributors Guidelines
Expand Down
8 changes: 7 additions & 1 deletion docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,18 @@ sns.despine(ax=ax)
### Converting NeMoS `Basis` to a transformer
In order to use NeMoS [`Basis`](nemos.basis._basis.Basis) in a pipeline, we need to convert it into a scikit-learn transformer. This can be achieved through the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) wrapper class.

Instantiating a [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) can be done with [`Basis.to_transformer()`](nemos.basis._basis.Basis.to_transformer):
Instantiating a [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) can be done either using by the constructor directly or with [`Basis.to_transformer()`](nemos.basis._basis.Basis.to_transformer):


```{code-cell} ipython3
bas = nmo.basis.RaisedCosineLinearConv(5, window_size=5)
# initalize using the constructor
trans_bas = nmo.basis.TransformerBasis(bas)
# equivalent initialization via "to_transformer"
trans_bas = bas.to_transformer()
```

[`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) provides convenient access to the underlying [`Basis`](nemos.basis._basis.Basis) object's attributes:
Expand Down
137 changes: 93 additions & 44 deletions src/nemos/basis/_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import jax
import numpy as np
from numpy.typing import ArrayLike, NDArray
from pynapple import Tsd, TsdFrame
from pynapple import Tsd, TsdFrame, TsdTensor

from ..base_class import Base
from ..type_casting import support_pynapple
Expand Down Expand Up @@ -242,21 +242,23 @@ def add_constant(x):
return X

@check_transform_input
def compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
def compute_features(
self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor
) -> FeatureMatrix:
"""
Apply the basis transformation to the input data.
This method is designed to be a high-level interface for transforming input
data using the basis functions defined by the subclass. Depending on the basis'
mode ('eval' or 'conv'), it either evaluates the basis functions at the sample
mode ('Eval' or 'Conv'), it either evaluates the basis functions at the sample
points or performs a convolution operation between the input data and the
basis functions.
Parameters
----------
*xi :
Input data arrays to be transformed. The shape and content requirements
depend on the subclass and mode of operation ('eval' or 'conv').
depend on the subclass and mode of operation ('Eval' or 'Conv').
Returns
-------
Expand All @@ -276,7 +278,9 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
return self._compute_features(*xi)

@abc.abstractmethod
def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
def _compute_features(
self, *xi: NDArray | Tsd | TsdFrame | TsdTensor
) -> FeatureMatrix:
"""Convolve or evaluate the basis."""
pass

Expand All @@ -286,7 +290,7 @@ def _set_kernel(self):
pass

@abc.abstractmethod
def _evaluate(self, *xi: ArrayLike) -> FeatureMatrix:
def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
"""
Abstract method to evaluate the basis functions at given points.
Expand Down Expand Up @@ -550,9 +554,6 @@ def _get_feature_slicing(
Calculate and return the slicing for features based on the input structure.
This method determines how to slice the features for different basis types.
If the instance is of ``AdditiveBasis`` type, the slicing is calculated recursively
for each component basis. Otherwise, it determines the slicing based on
the number of basis functions and ``split_by_input`` flag.
Parameters
----------
Expand Down Expand Up @@ -582,37 +583,14 @@ def _get_feature_slicing(
n_inputs = n_inputs or self._n_basis_input
start_slice = start_slice or 0

# If the instance is of AdditiveBasis type, handle slicing for the additive components
if isinstance(self, AdditiveBasis):
split_dict, start_slice = self._basis1._get_feature_slicing(
n_inputs[: len(self._basis1._n_basis_input)],
start_slice,
split_by_input=split_by_input,
)
sp2, start_slice = self._basis2._get_feature_slicing(
n_inputs[len(self._basis1._n_basis_input) :],
start_slice,
split_by_input=split_by_input,
)
split_dict = self._merge_slicing_dicts(split_dict, sp2)
else:
# Handle the default case for other basis types
split_dict, start_slice = self._get_default_slicing(
split_by_input, start_slice
)
# Handle the default case for non-additive basis types
# See overwritten method for recursion logic
split_dict, start_slice = self._get_default_slicing(
split_by_input=split_by_input, start_slice=start_slice
)

return split_dict, start_slice

def _merge_slicing_dicts(self, dict1: dict, dict2: dict) -> dict:
"""Merge two slicing dictionaries, handling key conflicts."""
for key, val in dict2.items():
if key in dict1:
new_key = self._generate_unique_key(dict1, key)
dict1[new_key] = val
else:
dict1[key] = val
return dict1

@staticmethod
def _generate_unique_key(existing_dict: dict, key: str) -> str:
"""Generate a unique key if there is a conflict."""
Expand Down Expand Up @@ -887,7 +865,7 @@ def _check_n_basis_min(self) -> None:
@support_pynapple(conv_type="numpy")
@check_transform_input
@check_one_dimensional
def _evaluate(self, *xi: ArrayLike) -> FeatureMatrix:
def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
"""
Evaluate the basis at the input samples.
Expand Down Expand Up @@ -927,7 +905,9 @@ def _evaluate(self, *xi: ArrayLike) -> FeatureMatrix:
return X

@add_docstring("compute_features", Basis)
def compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
def compute_features(
self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor
) -> FeatureMatrix:
r"""
Examples
--------
Expand All @@ -944,7 +924,9 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
"""
return super().compute_features(*xi)

def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
def _compute_features(
self, *xi: NDArray | Tsd | TsdFrame | TsdTensor
) -> FeatureMatrix:
"""
Compute features for added bases and concatenate.
Expand Down Expand Up @@ -1162,6 +1144,70 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]:
"""
return super().evaluate_on_grid(*n_samples)

def _get_feature_slicing(
self,
n_inputs: Optional[tuple] = None,
start_slice: Optional[int] = None,
split_by_input: bool = True,
) -> Tuple[dict, int]:
"""
Calculate and return the slicing for features based on the input structure.
This method determines how to slice the features for different basis types.
Parameters
----------
n_inputs :
The number of input basis for each component, by default it uses ``self._n_basis_input``.
start_slice :
The starting index for slicing, by default it starts from 0.
split_by_input :
Flag indicating whether to split the slicing by individual inputs or not.
If ``False``, a single slice is generated for all inputs.
Returns
-------
split_dict :
Dictionary with keys as labels and values as slices representing
the slicing for each input or additive component, if split_by_input equals to
True or False respectively.
start_slice :
The updated starting index after slicing.
See Also
--------
_get_default_slicing : Handles default slicing logic.
_merge_slicing_dicts : Merges multiple slicing dictionaries, handling keys conflicts.
"""
# Set default values for n_inputs and start_slice if not provided
n_inputs = n_inputs or self._n_basis_input
start_slice = start_slice or 0

# If the instance is of AdditiveBasis type, handle slicing for the additive components

split_dict, start_slice = self._basis1._get_feature_slicing(
n_inputs[: len(self._basis1._n_basis_input)],
start_slice,
split_by_input=split_by_input,
)
sp2, start_slice = self._basis2._get_feature_slicing(
n_inputs[len(self._basis1._n_basis_input) :],
start_slice,
split_by_input=split_by_input,
)
split_dict = self._merge_slicing_dicts(split_dict, sp2)
return split_dict, start_slice

def _merge_slicing_dicts(self, dict1: dict, dict2: dict) -> dict:
"""Merge two slicing dictionaries, handling key conflicts."""
for key, val in dict2.items():
if key in dict1:
new_key = self._generate_unique_key(dict1, key)
dict1[new_key] = val
else:
dict1[key] = val
return dict1


class MultiplicativeBasis(Basis):
"""
Expand Down Expand Up @@ -1208,7 +1254,6 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None:
self._label = "(" + basis1.label + " * " + basis2.label + ")"
self._basis1 = basis1
self._basis2 = basis2
BasisTransformerMixin.__init__(self)

def _check_n_basis_min(self) -> None:
pass
Expand All @@ -1235,7 +1280,7 @@ def _set_kernel(self, *xi: NDArray) -> Basis:
@support_pynapple(conv_type="numpy")
@check_transform_input
@check_one_dimensional
def _evaluate(self, *xi: ArrayLike) -> FeatureMatrix:
def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
"""
Evaluate the basis at the input samples.
Expand Down Expand Up @@ -1267,7 +1312,9 @@ def _evaluate(self, *xi: ArrayLike) -> FeatureMatrix:
)
return X

def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
def _compute_features(
self, *xi: NDArray | Tsd | TsdFrame | TsdTensor
) -> FeatureMatrix:
"""
Compute the features for the multiplied bases, and compute their outer product.
Expand Down Expand Up @@ -1360,7 +1407,9 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]:
return super().evaluate_on_grid(*n_samples)

@add_docstring("compute_features", Basis)
def compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
def compute_features(
self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor
) -> FeatureMatrix:
"""
Examples
--------
Expand Down
Loading

0 comments on commit bcf72ce

Please sign in to comment.