Skip to content

Commit

Permalink
re-struct bases
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 13, 2024
1 parent 2778fb9 commit ce5dff2
Show file tree
Hide file tree
Showing 11 changed files with 1,310 additions and 996 deletions.
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ profile = "black"
# Configure pytest
[tool.pytest.ini_options]
testpaths = ["tests"] # Specify the directory where test files are located
filterwarnings = [
# note the use of single quote below to denote "raw" strings in TOML
'ignore:plotting functions contained within:UserWarning',
'ignore:Tolerance of \d\.\d+e-\d\d reached:RuntimeWarning',
]

[tool.coverage.run]
omit = [
Expand Down
382 changes: 206 additions & 176 deletions src/nemos/basis/_basis.py

Large diffs are not rendered by default.

198 changes: 181 additions & 17 deletions src/nemos/basis/_basis_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from __future__ import annotations

import abc
import copy
import inspect
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union

import numpy as np
Expand All @@ -20,8 +22,11 @@
class EvalBasisMixin:
"""Mixin class for evaluational basis."""

def __init__(self, bounds: Optional[Tuple[float, float]] = None):
def __init__(
self, n_basis_funcs: int, bounds: Optional[Tuple[float, float]] = None
):
self.bounds = bounds
self._n_basis_funcs = n_basis_funcs

def _compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor):
"""Evaluate basis at sample points.
Expand Down Expand Up @@ -51,9 +56,32 @@ def _compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor):
out = self._evaluate(*(np.reshape(x, (x.shape[0], -1)) for x in xi))
return np.reshape(out, (out.shape[0], -1))

def set_kernel(self) -> "EvalBasisMixin":
def setup_basis(self, *xi: NDArray) -> Basis:
"""
Prepare or compute the convolutional kernel for the basis functions.
Set all basis states.
This method corresponds to the sklearn transformer ``fit``. Like ``fit``, it must receive the input and
it must set all basis states, i.e. kernel_ and all the states relative to the input shape.
The difference between this method and the transformer ``fit`` is in the expected input structure,
where the transformer ``fit`` method requires the inputs to be concatenated in a 2D array, while here
each input is provided as a separate time series for each basis element.
Parameters
----------
xi:
Input arrays.
Returns
-------
:
The basis, ready for evaluation.
"""
self.set_input_shape(*xi)
return self

def _set_input_independent_states(self) -> "EvalBasisMixin":
"""
Compute all the basis states that do not depend on the input.
For EvalBasisMixin, this method might not perform any operation but simply return the
instance itself, as no kernel preparation is necessary.
Expand Down Expand Up @@ -94,9 +122,13 @@ def bounds(self, values: Union[None, Tuple[float, float]]):
class ConvBasisMixin:
"""Mixin class for convolutional basis."""

def __init__(self, window_size: int, conv_kwargs: Optional[dict] = None):
def __init__(
self, n_basis_funcs: int, window_size: int, conv_kwargs: Optional[dict] = None
):
self.kernel_ = None
self.window_size = window_size
self.conv_kwargs = {} if conv_kwargs is None else conv_kwargs
self._n_basis_funcs = n_basis_funcs

def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor):
"""Convolve basis functions with input time series.
Expand All @@ -114,10 +146,18 @@ def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor):
The input data over which to apply the basis transformation. The samples can be passed
as multiple arguments, each representing a different dimension for multivariate inputs.
Notes
-----
This method is intended to map 1-to-1 onto the ``transform`` method of an sklearn transformer. This
means that, for the method to be callable, all the state attributes have to be pre-computed in a
method that maps onto ``fit``, which for us is ``setup_basis``. It is fundamental that both
methods behave like the corresponding transformer methods, with the only difference being the input
structure: a single (X, y) pair for the transformer, a number of time series for the Basis.
"""
if self.kernel_ is None:
raise ValueError(
"You must call `_set_kernel` before `_compute_features`! "
"You must call `setup_basis` before `_compute_features`! "
"Convolution kernel is not set."
)
# before calling the convolve, check that the input matches
Expand All @@ -127,6 +167,38 @@ def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor):
# make sure to return a matrix
return np.reshape(conv, newshape=(conv.shape[0], -1))

def setup_basis(self, *xi: NDArray) -> Basis:
    """
    Initialize every state of the basis from the provided inputs.

    This is the basis-level analogue of the ``fit`` method of an sklearn
    transformer: after the call, both the convolution kernel and the states
    derived from the input shape are set. Unlike the transformer ``fit``,
    which expects all inputs concatenated in a single 2D array, each input
    is passed here as a separate positional argument.

    Parameters
    ----------
    xi:
        Input arrays, forwarded unchanged to ``set_input_shape``.

    Returns
    -------
    :
        The same basis instance, ready for evaluation.
    """
    # input-independent state first: the convolution kernel
    self.set_kernel()

    # then the state that depends on the inputs
    self.set_input_shape(*xi)

    return self

def _set_input_independent_states(self):
    """
    Set the basis states that can be computed without seeing any input.

    For the convolutional mixin this amounts to a single call to
    ``set_kernel``.

    Returns
    -------
    :
        Whatever ``set_kernel`` returns.
    """
    kernel_ready = self.set_kernel()
    return kernel_ready

def set_kernel(self) -> "ConvBasisMixin":
"""
Prepare or compute the convolutional kernel for the basis functions.
Expand Down Expand Up @@ -160,6 +232,11 @@ def window_size(self):
@window_size.setter
def window_size(self, window_size):
"""Setter for the window size parameter."""
self._check_window_size(window_size)

self._window_size = window_size

def _check_window_size(self, window_size):
if window_size is None:
raise ValueError("You must provide a window_size!")

Expand All @@ -168,8 +245,6 @@ def window_size(self, window_size):
f"`window_size` must be a positive integer. {window_size} provided instead!"
)

self._window_size = window_size

@property
def conv_kwargs(self):
"""The convolutional kwargs.
Expand Down Expand Up @@ -227,6 +302,13 @@ def _check_convolution_kwargs(conv_kwargs: dict):
f"Allowed convolution keyword arguments are: {convolve_configs}."
)

def _check_has_kernel(self) -> None:
    """
    Check that the convolution kernel has been pre-computed.

    Raises
    ------
    ValueError
        If ``self.kernel_`` is ``None``.
    """
    if self.kernel_ is None:
        # Point the user to ``setup_basis``; there is no ``_set_kernel`` method.
        raise ValueError(
            "You must call `setup_basis` before `_compute_features` for Conv basis."
        )


class BasisTransformerMixin:
"""Mixin class for constructing a transformer."""
Expand All @@ -244,7 +326,7 @@ def to_transformer(self) -> TransformerBasis:
>>> from sklearn.model_selection import GridSearchCV
>>> # load some data
>>> X, y = np.random.normal(size=(30, 1)), np.random.poisson(size=30)
>>> basis = nmo.basis.RaisedCosineLinearEval(10).to_transformer()
>>> basis = nmo.basis.RaisedCosineLinearEval(10).set_input_shape(1).to_transformer()
>>> glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=1.)
>>> pipeline = Pipeline([("basis", basis), ("glm", glm)])
>>> param_grid = dict(
Expand All @@ -258,7 +340,7 @@ def to_transformer(self) -> TransformerBasis:
... )
>>> gridsearch = gridsearch.fit(X, y)
"""
return TransformerBasis(copy.deepcopy(self))
return TransformerBasis(self)


class CompositeBasisMixin:
Expand All @@ -268,28 +350,82 @@ class CompositeBasisMixin:
(AdditiveBasis and MultiplicativeBasis).
"""

def __init__(self, basis1: Basis, basis2: Basis):
    """
    Store copies of the two operands and propagate their input shapes.

    Parameters
    ----------
    basis1 :
        First operand of the composition; a deep copy is stored.
    basis2 :
        Second operand of the composition; a deep copy is stored.

    Warns
    -----
    UserWarning
        If only some of the components have an input shape set.
    """
    # deep copy to avoid changes directly to the 1d basis to be reflected
    # in the composite basis.
    self.basis1 = copy.deepcopy(basis1)
    self.basis2 = copy.deepcopy(basis2)

    # set parents
    self.basis1._parent = self
    self.basis2._parent = self

    shapes = (
        *(bas1._input_shape_ for bas1 in basis1._list_components()),
        *(bas2._input_shape_ for bas2 in basis2._list_components()),
    )
    # Materialize the flags in a list. With a generator, ``all`` consumes
    # items up to the first unset shape and ``any`` then only sees the
    # remainder, so the warning below could be silently skipped.
    is_set = [s is not None for s in shapes]

    if all(is_set):
        # every basis was set: pass down the input shapes
        self.set_input_shape(*shapes)
    elif any(is_set):
        warnings.warn(
            "Only some of the basis where initialized with `set_input_shape`, "
            "please initialize the composite basis before computing features.",
            category=UserWarning,
        )

@property
@abc.abstractmethod
def n_basis_funcs(self):
    """Read-only number of basis functions; must be implemented by composite bases."""

def _check_n_basis_min(self) -> None:
    """Perform no check; this is a deliberate no-op for composite bases."""
    # NOTE(review): presumably overrides a minimum-count check defined on the
    # base class -- confirm against ``Basis``.
    return None

def set_kernel(self, *xi: NDArray) -> Basis:
"""Call set_kernel on the basis elements.
def setup_basis(self, *xi: NDArray) -> Basis:
"""
Set all basis states.
If any of the basis elements is in "conv" mode, it will prepare its kernels for the convolution.
This method corresponds to the sklearn transformer ``fit``. Like ``fit``, it must receive the input and
it must set all basis states, i.e. kernel_ and all the states relative to the input shape.
The difference between this method and the transformer ``fit`` is in the expected input structure,
where the transformer ``fit`` method requires the inputs to be concatenated in a 2D array, while here
each input is provided as a separate time series for each basis element.
Parameters
----------
*xi:
The sample inputs. Unused, necessary to conform to ``scikit-learn`` API.
xi:
Input arrays.
Returns
-------
:
The basis ready to be evaluated.
The basis, ready for evaluation.
"""
self._basis1.set_kernel()
self._basis2.set_kernel()
# setup both input independent
self._set_input_independent_states()

# and input dependent states
self.set_input_shape(*xi)

return self

def _set_input_independent_states(self):
    """
    Compute the input-independent states of both components.

    The call is delegated to ``basis1`` and then ``basis2``, so a nested
    composite basis is traversed down to its components.

    Returns
    -------
    :
        The composite basis itself, with the states stored as attributes of
        each component.
    """
    self.basis1._set_input_independent_states()
    self.basis2._set_input_independent_states()
    # Return the instance, as documented above.
    return self

def _check_input_shape_consistency(self, *xi: NDArray):
"""Check the input shape consistency for all basis elements."""
self._basis1._check_input_shape_consistency(
Expand All @@ -298,3 +434,31 @@ def _check_input_shape_consistency(self, *xi: NDArray):
self._basis2._check_input_shape_consistency(
*xi[self._basis1._n_input_dimensionality :]
)

@property
def basis1(self):
    """First component of the composite basis, as stored in ``_basis1``."""
    return self._basis1

@basis1.setter
def basis1(self, bas: Basis):
    """Store ``bas`` as the first component; no copy or validation is done here."""
    self._basis1 = bas

@property
def basis2(self):
    """Second component of the composite basis, as stored in ``_basis2``."""
    return self._basis2

@basis2.setter
def basis2(self, bas: Basis):
    """Store ``bas`` as the second component; no copy or validation is done here."""
    self._basis2 = bas

def _list_components(self):
    """List all basis components.

    Overrides the default behavior by recursively calling ``_list_components``
    on the two operands and concatenating the results, first operand first.

    Returns
    -------
        The components of ``_basis1`` followed by those of ``_basis2``.
    """
    first_components = self._basis1._list_components()
    second_components = self._basis2._list_components()
    return first_components + second_components
4 changes: 0 additions & 4 deletions src/nemos/basis/_decaying_exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ class OrthExponentialBasis(Basis, abc.ABC):
Parameters
----------
n_basis_funcs
Number of basis functions.
decay_rates :
Decay rates of the exponentials, shape ``(n_basis_funcs,)``.
mode :
Expand All @@ -35,13 +33,11 @@ class OrthExponentialBasis(Basis, abc.ABC):

def __init__(
self,
n_basis_funcs: int,
decay_rates: NDArray[np.floating],
mode="eval",
label: Optional[str] = "OrthExponentialBasis",
):
super().__init__(
n_basis_funcs,
mode=mode,
label=label,
)
Expand Down
6 changes: 0 additions & 6 deletions src/nemos/basis/_raised_cosine_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ class RaisedCosineBasisLinear(Basis, abc.ABC):
Parameters
----------
n_basis_funcs :
The number of basis functions.
mode :
The mode of operation. 'eval' for evaluation at sample points,
'conv' for convolutional operation.
Expand All @@ -42,13 +40,11 @@ class RaisedCosineBasisLinear(Basis, abc.ABC):

def __init__(
self,
n_basis_funcs: int,
mode="eval",
width: float = 2.0,
label: Optional[str] = "RaisedCosineBasisLinear",
) -> None:
super().__init__(
n_basis_funcs,
mode=mode,
label=label,
)
Expand Down Expand Up @@ -234,15 +230,13 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear, abc.ABC):

def __init__(
self,
n_basis_funcs: int,
mode="eval",
width: float = 2.0,
time_scaling: float = None,
enforce_decay_to_zero: bool = True,
label: Optional[str] = "RaisedCosineBasisLog",
) -> None:
super().__init__(
n_basis_funcs,
mode=mode,
width=width,
label=label,
Expand Down
Loading

0 comments on commit ce5dff2

Please sign in to comment.