Skip to content

Commit

Permalink
Merge pull request #254 from flatironinstitute/bug_fix_sample_axis_basis
Browse files Browse the repository at this point in the history
require sample axis to be axis=0
  • Loading branch information
BalzaniEdoardo authored Oct 25, 2024
2 parents 6b3a1e8 + d51f8d3 commit 6645096
Show file tree
Hide file tree
Showing 2 changed files with 326 additions and 41 deletions.
141 changes: 106 additions & 35 deletions src/nemos/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import abc
import copy
import inspect
from functools import wraps
from typing import Callable, Generator, Literal, Optional, Tuple, Union

Expand Down Expand Up @@ -468,10 +469,22 @@ class Basis(Base, abc.ABC):
bounds :
The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bonuds, the basis will return NaN.
If a sample is outside the bounds, the basis will return NaN.
**kwargs :
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`
Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
`mode='conv'`; These arguments are used to change the default behavior of the convolution.
For example, changing the `predictor_causality`, which by default is set to `"causal"`.
Note that one cannot change the default value for the `axis` parameter. Basis assumes
that the convolution axis is `axis=0`.
Raises
------
ValueError:
- If `mode` is not 'eval' or 'conv'.
- If `kwargs` are not None and `mode =="eval".
- If `kwargs` include parameters not recognized or do not have
default values in `create_convolutional_predictor`.
- If `axis` different from 0 is provided as a keyword argument (samples must always be in the first axis).
"""

def __init__(
Expand All @@ -496,14 +509,58 @@ def __init__(
self.window_size = window_size
self.bounds = bounds

if mode == "eval" and kwargs:
raise ValueError(
f"kwargs should only be set when mode=='conv', but '{mode}' provided instead!"
)
self._check_convolution_kwargs()

self.kernel_ = None
self._identifiability_constraints = False

def _check_convolution_kwargs(self):
"""Check convolution kwargs settings.
Raises
------
ValueError:
- If `self._conv_kwargs` are not None and `mode =="eval".
- If `axis` is provided as an argument, and it is different from 0
(samples must always be in the first axis).
- If `self._conv_kwargs` include parameters not recognized or that do not have
default values in `create_convolutional_predictor`.
"""
if self._mode == "eval" and self._conv_kwargs:
raise ValueError(
f"kwargs should only be set when mode=='conv', but '{self._mode}' provided instead!"
)

if "axis" in self._conv_kwargs:
raise ValueError(
"Setting the `axis` parameter is not allowed. Basis requires the "
"convolution to be applied along the first axis (`axis=0`).\n"
"Please transpose your input so that the desired axis for "
"convolution is the first dimension (axis=0)."
)
convolve_params = inspect.signature(create_convolutional_predictor).parameters
convolve_configs = {
key
for key, param in convolve_params.items()
if param.default
# prevent user from passing
# `basis_matrix` or `time_series` in kwargs.
is not inspect.Parameter.empty
}
if not set(self._conv_kwargs.keys()).issubset(convolve_configs):
# do not encourage to set axis.
convolve_configs = convolve_configs.difference({"axis"})
# remove the parameter in case axis=0 was passed, since it is allowed.
invalid = (
set(self._conv_kwargs.keys())
.difference(convolve_configs)
.difference({"axis"})
)
raise ValueError(
f"Unrecognized keyword arguments: {invalid}. "
f"Allowed convolution keyword arguments are: {convolve_configs}."
)

@property
def n_basis_funcs(self):
return self._n_basis_funcs
Expand Down Expand Up @@ -666,18 +723,11 @@ def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
if self.mode == "eval": # evaluate at the sample
return self.__call__(*xi)
else: # convolve, called only at the last layer
if "axis" not in self._conv_kwargs:
axis = 0
else:
axis = self._conv_kwargs["axis"]
# convolve called at the end of any recursive call
# this ensures that len(xi) == 1.
conv = create_convolutional_predictor(
self.kernel_, *xi, **self._conv_kwargs
)
# move the time axis to the first dimension
new_axis = (np.arange(conv.ndim) + axis) % conv.ndim
conv = np.transpose(conv, new_axis)
# make sure to return a matrix
return np.reshape(conv, newshape=(conv.shape[0], -1))

Expand Down Expand Up @@ -1289,10 +1339,13 @@ class SplineBasis(Basis, abc.ABC):
bounds :
The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bonuds, the basis will return NaN.
If a sample is outside the bounds, the basis will return NaN.
**kwargs :
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`
Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
`mode='conv'`; These arguments are used to change the default behavior of the convolution.
For example, changing the `predictor_causality`, which by default is set to `"causal"`.
Note that one cannot change the default value for the `axis` parameter. Basis assumes
that the convolution axis is `axis=0`.
Attributes
----------
Expand Down Expand Up @@ -1440,10 +1493,13 @@ class MSplineBasis(SplineBasis):
bounds :
The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bonuds, the basis will return NaN.
If a sample is outside the bounds, the basis will return NaN.
**kwargs:
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`
Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
`mode='conv'`; These arguments are used to change the default behavior of the convolution.
For example, changing the `predictor_causality`, which by default is set to `"causal"`.
Note that one cannot change the default value for the `axis` parameter. Basis assumes
that the convolution axis is `axis=0`.
Examples
--------
Expand Down Expand Up @@ -1598,10 +1654,13 @@ class BSplineBasis(SplineBasis):
bounds :
The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bonuds, the basis will return NaN.
If a sample is outside the bounds, the basis will return NaN.
**kwargs :
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`
Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
`mode='conv'`; These arguments are used to change the default behavior of the convolution.
For example, changing the `predictor_causality`, which by default is set to `"causal"`.
Note that one cannot change the default value for the `axis` parameter. Basis assumes
that the convolution axis is `axis=0`.
Attributes
----------
Expand Down Expand Up @@ -1717,10 +1776,13 @@ class CyclicBSplineBasis(SplineBasis):
bounds :
The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bonuds, the basis will return NaN.
If a sample is outside the bounds, the basis will return NaN.
**kwargs :
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`
Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
`mode='conv'`; These arguments are used to change the default behavior of the convolution.
For example, changing the `predictor_causality`, which by default is set to `"causal"`.
Note that one cannot change the default value for the `axis` parameter. Basis assumes
that the convolution axis is `axis=0`.
Attributes
----------
Expand Down Expand Up @@ -1859,10 +1921,13 @@ class RaisedCosineBasisLinear(Basis):
bounds :
The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bonuds, the basis will return NaN.
If a sample is outside the bounds, the basis will return NaN.
**kwargs :
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`
Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
`mode='conv'`; These arguments are used to change the default behavior of the convolution.
For example, changing the `predictor_causality`, which by default is set to `"causal"`.
Note that one cannot change the default value for the `axis` parameter. Basis assumes
that the convolution axis is `axis=0`.
# References
------------
Expand Down Expand Up @@ -2052,10 +2117,13 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear):
bounds :
The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bonuds, the basis will return NaN.
If a sample is outside the bounds, the basis will return NaN.
**kwargs :
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`
Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
`mode='conv'`; These arguments are used to change the default behavior of the convolution.
For example, changing the `predictor_causality`, which by default is set to `"causal"`.
Note that one cannot change the default value for the `axis` parameter. Basis assumes
that the convolution axis is `axis=0`.
# References
------------
Expand Down Expand Up @@ -2206,10 +2274,13 @@ class OrthExponentialBasis(Basis):
bounds :
The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bonuds, the basis will return NaN.
If a sample is outside the bounds, the basis will return NaN.
**kwargs :
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`
Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
`mode='conv'`; These arguments are used to change the default behavior of the convolution.
For example, changing the `predictor_causality`, which by default is set to `"causal"`.
Note that one cannot change the default value for the `axis` parameter. Basis assumes
that the convolution axis is `axis=0`.
"""

def __init__(
Expand Down
Loading

0 comments on commit 6645096

Please sign in to comment.