Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

require sample axis to be axis=0 #254

Merged
merged 15 commits into from
Oct 25, 2024
141 changes: 106 additions & 35 deletions src/nemos/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import abc
import copy
import inspect
from functools import wraps
from typing import Callable, Generator, Literal, Optional, Tuple, Union

Expand Down Expand Up @@ -468,10 +469,22 @@ class Basis(Base, abc.ABC):
bounds :
The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bonuds, the basis will return NaN.
If a sample is outside the bounds, the basis will return NaN.
**kwargs :
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`
Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
`mode='conv'`; These arguments are used to change the default behavior of the convolution.
For example, changing the `predictor_causality`, which by default is set to `"causal"`.
Note that one cannot change the default value for the `axis` parameter. Basis assumes
that the convolution axis is `axis=0`.
Raises
------
ValueError:
- If `mode` is not 'eval' or 'conv'.
- If `kwargs` are not None and `mode =="eval".
- If `kwargs` include parameters not recognized or do not have
default values in `create_convolutional_predictor`.
- If `axis` different from 0 is provided as a keyword argument (samples must always be in the first axis).
"""

def __init__(
Expand All @@ -496,14 +509,58 @@ def __init__(
self.window_size = window_size
self.bounds = bounds

if mode == "eval" and kwargs:
raise ValueError(
f"kwargs should only be set when mode=='conv', but '{mode}' provided instead!"
)
self._check_convolution_kwargs()

self.kernel_ = None
self._identifiability_constraints = False

def _check_convolution_kwargs(self):
"""Check convolution kwargs settings.
Raises
------
ValueError:
- If `self._conv_kwargs` are not None and `mode =="eval".
- If `axis` is provided as an argument, and it is different from 0
(samples must always be in the first axis).
- If `self._conv_kwargs` include parameters not recognized or that do not have
default values in `create_convolutional_predictor`.
"""
if self._mode == "eval" and self._conv_kwargs:
raise ValueError(
f"kwargs should only be set when mode=='conv', but '{self._mode}' provided instead!"
)

if "axis" in self._conv_kwargs:
raise ValueError(
"Setting the `axis` parameter is not allowed. Basis requires the "
"convolution to be applied along the first axis (`axis=0`).\n"
"Please transpose your input so that the desired axis for "
"convolution is the first dimension (axis=0)."
)
convolve_params = inspect.signature(create_convolutional_predictor).parameters
convolve_configs = {
key
for key, param in convolve_params.items()
if param.default
# prevent user from passing
# `basis_matrix` or `time_series` in kwargs.
is not inspect.Parameter.empty
}
if not set(self._conv_kwargs.keys()).issubset(convolve_configs):
# do not encourage to set axis.
convolve_configs = convolve_configs.difference({"axis"})
# remove the parameter in case axis=0 was passed, since it is allowed.
invalid = (
set(self._conv_kwargs.keys())
.difference(convolve_configs)
.difference({"axis"})
)
raise ValueError(
f"Unrecognized keyword arguments: {invalid}. "
f"Allowed convolution keyword arguments are: {convolve_configs}."
)

@property
def n_basis_funcs(self):
return self._n_basis_funcs
Expand Down Expand Up @@ -666,18 +723,11 @@ def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
if self.mode == "eval": # evaluate at the sample
return self.__call__(*xi)
else: # convolve, called only at the last layer
if "axis" not in self._conv_kwargs:
axis = 0
else:
axis = self._conv_kwargs["axis"]
# convolve called at the end of any recursive call
# this ensures that len(xi) == 1.
conv = create_convolutional_predictor(
self.kernel_, *xi, **self._conv_kwargs
)
# move the time axis to the first dimension
new_axis = (np.arange(conv.ndim) + axis) % conv.ndim
conv = np.transpose(conv, new_axis)
# make sure to return a matrix
return np.reshape(conv, newshape=(conv.shape[0], -1))

Expand Down Expand Up @@ -1289,10 +1339,13 @@ class SplineBasis(Basis, abc.ABC):
bounds :
The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bonuds, the basis will return NaN.
If a sample is outside the bounds, the basis will return NaN.
**kwargs :
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`
Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
`mode='conv'`; These arguments are used to change the default behavior of the convolution.
For example, changing the `predictor_causality`, which by default is set to `"causal"`.
Note that one cannot change the default value for the `axis` parameter. Basis assumes
that the convolution axis is `axis=0`.
Attributes
----------
Expand Down Expand Up @@ -1440,10 +1493,13 @@ class MSplineBasis(SplineBasis):
bounds :
The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bonuds, the basis will return NaN.
If a sample is outside the bounds, the basis will return NaN.
**kwargs:
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`
Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
`mode='conv'`; These arguments are used to change the default behavior of the convolution.
For example, changing the `predictor_causality`, which by default is set to `"causal"`.
Note that one cannot change the default value for the `axis` parameter. Basis assumes
that the convolution axis is `axis=0`.
Examples
--------
Expand Down Expand Up @@ -1598,10 +1654,13 @@ class BSplineBasis(SplineBasis):
bounds :
The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bonuds, the basis will return NaN.
If a sample is outside the bounds, the basis will return NaN.
**kwargs :
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`
Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
`mode='conv'`; These arguments are used to change the default behavior of the convolution.
For example, changing the `predictor_causality`, which by default is set to `"causal"`.
Note that one cannot change the default value for the `axis` parameter. Basis assumes
that the convolution axis is `axis=0`.
Attributes
----------
Expand Down Expand Up @@ -1717,10 +1776,13 @@ class CyclicBSplineBasis(SplineBasis):
bounds :
The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bonuds, the basis will return NaN.
If a sample is outside the bounds, the basis will return NaN.
**kwargs :
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`
Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
`mode='conv'`; These arguments are used to change the default behavior of the convolution.
For example, changing the `predictor_causality`, which by default is set to `"causal"`.
Note that one cannot change the default value for the `axis` parameter. Basis assumes
that the convolution axis is `axis=0`.
Attributes
----------
Expand Down Expand Up @@ -1859,10 +1921,13 @@ class RaisedCosineBasisLinear(Basis):
bounds :
The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bonuds, the basis will return NaN.
If a sample is outside the bounds, the basis will return NaN.
**kwargs :
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`
Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
`mode='conv'`; These arguments are used to change the default behavior of the convolution.
For example, changing the `predictor_causality`, which by default is set to `"causal"`.
Note that one cannot change the default value for the `axis` parameter. Basis assumes
that the convolution axis is `axis=0`.
# References
------------
Expand Down Expand Up @@ -2052,10 +2117,13 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear):
bounds :
The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bonuds, the basis will return NaN.
If a sample is outside the bounds, the basis will return NaN.
**kwargs :
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`
Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
`mode='conv'`; These arguments are used to change the default behavior of the convolution.
For example, changing the `predictor_causality`, which by default is set to `"causal"`.
Note that one cannot change the default value for the `axis` parameter. Basis assumes
that the convolution axis is `axis=0`.
# References
------------
Expand Down Expand Up @@ -2206,10 +2274,13 @@ class OrthExponentialBasis(Basis):
bounds :
The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bonuds, the basis will return NaN.
If a sample is outside the bounds, the basis will return NaN.
**kwargs :
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`
Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
`mode='conv'`; These arguments are used to change the default behavior of the convolution.
For example, changing the `predictor_causality`, which by default is set to `"causal"`.
Note that one cannot change the default value for the `axis` parameter. Basis assumes
that the convolution axis is `axis=0`.
"""

def __init__(
Expand Down
Loading
Loading