diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 1b0c9f12..f0c988ad 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -5,6 +5,7 @@ import abc import copy +import inspect from functools import wraps from typing import Callable, Generator, Literal, Optional, Tuple, Union @@ -468,10 +469,22 @@ class Basis(Base, abc.ABC): bounds : The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bonuds, the basis will return NaN. + If a sample is outside the bounds, the basis will return NaN. **kwargs : - Only used in "conv" mode. Additional keyword arguments that are passed to - `nemos.convolve.create_convolutional_predictor` + Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when + `mode='conv'`; These arguments are used to change the default behavior of the convolution. + For example, changing the `predictor_causality`, which by default is set to `"causal"`. + Note that one cannot change the default value for the `axis` parameter. Basis assumes + that the convolution axis is `axis=0`. + + Raises + ------ + ValueError: + - If `mode` is not 'eval' or 'conv'. + - If `kwargs` are not None and `mode =="eval". + - If `kwargs` include parameters not recognized or do not have + default values in `create_convolutional_predictor`. + - If `axis` different from 0 is provided as a keyword argument (samples must always be in the first axis). """ def __init__( @@ -496,14 +509,58 @@ def __init__( self.window_size = window_size self.bounds = bounds - if mode == "eval" and kwargs: - raise ValueError( - f"kwargs should only be set when mode=='conv', but '{mode}' provided instead!" - ) + self._check_convolution_kwargs() self.kernel_ = None self._identifiability_constraints = False + def _check_convolution_kwargs(self): + """Check convolution kwargs settings. + + Raises + ------ + ValueError: + - If `self._conv_kwargs` are not None and `mode =="eval". + - If `axis` is provided as an argument, and it is different from 0 + (samples must always be in the first axis). + - If `self._conv_kwargs` include parameters not recognized or that do not have + default values in `create_convolutional_predictor`. + """ + if self._mode == "eval" and self._conv_kwargs: + raise ValueError( + f"kwargs should only be set when mode=='conv', but '{self._mode}' provided instead!" + ) + + if "axis" in self._conv_kwargs: + raise ValueError( + "Setting the `axis` parameter is not allowed. Basis requires the " + "convolution to be applied along the first axis (`axis=0`).\n" + "Please transpose your input so that the desired axis for " + "convolution is the first dimension (axis=0)." + ) + convolve_params = inspect.signature(create_convolutional_predictor).parameters + convolve_configs = { + key + for key, param in convolve_params.items() + if param.default + # prevent user from passing + # `basis_matrix` or `time_series` in kwargs. + is not inspect.Parameter.empty + } + if not set(self._conv_kwargs.keys()).issubset(convolve_configs): + # do not encourage to set axis. + convolve_configs = convolve_configs.difference({"axis"}) + # remove the parameter in case axis=0 was passed, since it is allowed. + invalid = ( + set(self._conv_kwargs.keys()) + .difference(convolve_configs) + .difference({"axis"}) + ) + raise ValueError( + f"Unrecognized keyword arguments: {invalid}. " + f"Allowed convolution keyword arguments are: {convolve_configs}." + ) + @property def n_basis_funcs(self): return self._n_basis_funcs @@ -666,18 +723,11 @@ def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix: if self.mode == "eval": # evaluate at the sample return self.__call__(*xi) else: # convolve, called only at the last layer - if "axis" not in self._conv_kwargs: - axis = 0 - else: - axis = self._conv_kwargs["axis"] # convolve called at the end of any recursive call # this ensures that len(xi) == 1. conv = create_convolutional_predictor( self.kernel_, *xi, **self._conv_kwargs ) - # move the time axis to the first dimension - new_axis = (np.arange(conv.ndim) + axis) % conv.ndim - conv = np.transpose(conv, new_axis) # make sure to return a matrix return np.reshape(conv, newshape=(conv.shape[0], -1)) @@ -1289,10 +1339,13 @@ class SplineBasis(Basis, abc.ABC): bounds : The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bonuds, the basis will return NaN. + If a sample is outside the bounds, the basis will return NaN. **kwargs : - Only used in "conv" mode. Additional keyword arguments that are passed to - `nemos.convolve.create_convolutional_predictor` + Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when + `mode='conv'`; These arguments are used to change the default behavior of the convolution. + For example, changing the `predictor_causality`, which by default is set to `"causal"`. + Note that one cannot change the default value for the `axis` parameter. Basis assumes + that the convolution axis is `axis=0`. Attributes ---------- @@ -1440,10 +1493,13 @@ class MSplineBasis(SplineBasis): bounds : The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bonuds, the basis will return NaN. + If a sample is outside the bounds, the basis will return NaN. **kwargs: - Only used in "conv" mode. Additional keyword arguments that are passed to - `nemos.convolve.create_convolutional_predictor` + Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when + `mode='conv'`; These arguments are used to change the default behavior of the convolution. + For example, changing the `predictor_causality`, which by default is set to `"causal"`. + Note that one cannot change the default value for the `axis` parameter. Basis assumes + that the convolution axis is `axis=0`. Examples -------- @@ -1598,10 +1654,13 @@ class BSplineBasis(SplineBasis): bounds : The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bonuds, the basis will return NaN. + If a sample is outside the bounds, the basis will return NaN. **kwargs : - Only used in "conv" mode. Additional keyword arguments that are passed to - `nemos.convolve.create_convolutional_predictor` + Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when + `mode='conv'`; These arguments are used to change the default behavior of the convolution. + For example, changing the `predictor_causality`, which by default is set to `"causal"`. + Note that one cannot change the default value for the `axis` parameter. Basis assumes + that the convolution axis is `axis=0`. Attributes ---------- @@ -1717,10 +1776,13 @@ class CyclicBSplineBasis(SplineBasis): bounds : The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bonuds, the basis will return NaN. + If a sample is outside the bounds, the basis will return NaN. **kwargs : - Only used in "conv" mode. Additional keyword arguments that are passed to - `nemos.convolve.create_convolutional_predictor` + Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when + `mode='conv'`; These arguments are used to change the default behavior of the convolution. + For example, changing the `predictor_causality`, which by default is set to `"causal"`. + Note that one cannot change the default value for the `axis` parameter. Basis assumes + that the convolution axis is `axis=0`. Attributes ---------- @@ -1859,10 +1921,13 @@ class RaisedCosineBasisLinear(Basis): bounds : The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bonuds, the basis will return NaN. + If a sample is outside the bounds, the basis will return NaN. **kwargs : - Only used in "conv" mode. Additional keyword arguments that are passed to - `nemos.convolve.create_convolutional_predictor` + Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when + `mode='conv'`; These arguments are used to change the default behavior of the convolution. + For example, changing the `predictor_causality`, which by default is set to `"causal"`. + Note that one cannot change the default value for the `axis` parameter. Basis assumes + that the convolution axis is `axis=0`. # References ------------ @@ -2052,10 +2117,13 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear): bounds : The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bonuds, the basis will return NaN. + If a sample is outside the bounds, the basis will return NaN. **kwargs : - Only used in "conv" mode. Additional keyword arguments that are passed to - `nemos.convolve.create_convolutional_predictor` + Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when + `mode='conv'`; These arguments are used to change the default behavior of the convolution. + For example, changing the `predictor_causality`, which by default is set to `"causal"`. + Note that one cannot change the default value for the `axis` parameter. Basis assumes + that the convolution axis is `axis=0`. # References ------------ @@ -2206,10 +2274,13 @@ class OrthExponentialBasis(Basis): bounds : The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bonuds, the basis will return NaN. + If a sample is outside the bounds, the basis will return NaN. **kwargs : - Only used in "conv" mode. Additional keyword arguments that are passed to - `nemos.convolve.create_convolutional_predictor` + Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when + `mode='conv'`; These arguments are used to change the default behavior of the convolution. + For example, changing the `predictor_causality`, which by default is set to `"causal"`. + Note that one cannot change the default value for the `axis` parameter. Basis assumes + that the convolution axis is `axis=0`. """ def __init__( diff --git a/tests/test_basis.py b/tests/test_basis.py index 3e21db33..fe10c43b 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -495,6 +495,44 @@ def test_init_mode(self, mode, expectation): with expectation: self.cls(5, mode=mode, window_size=window_size) + @pytest.mark.parametrize( + "conv_kwargs, expectation", + [ + (dict(), does_not_raise()), + ( + dict(axis=0), + pytest.raises( + ValueError, match="Setting the `axis` parameter is not allowed" + ), + ), + ( + dict(axis=1), + pytest.raises( + ValueError, match="Setting the `axis` parameter is not allowed" + ), + ), + (dict(shift=True), does_not_raise()), + ( + dict(shift=True, axis=0), + pytest.raises( + ValueError, match="Setting the `axis` parameter is not allowed" + ), + ), + ( + dict(shifts=True), + pytest.raises(ValueError, match="Unrecognized keyword arguments"), + ), + (dict(shift=True, predictor_causality="causal"), does_not_raise()), + ( + dict(shift=True, time_series=np.arange(10)), + pytest.raises(ValueError, match="Unrecognized keyword arguments"), + ), + ], + ) + def test_init_conv_kwargs(self, conv_kwargs, expectation): + with expectation: + self.cls(5, mode="conv", window_size=200, **conv_kwargs) + @pytest.mark.parametrize( "mode, ws, expectation", [ @@ -567,7 +605,10 @@ def test_set_params( for i in range(len(pars)): for j in range(i + 1, len(pars)): - with pytest.raises(AttributeError, match="can't set attribute 'mode'|property 'mode' of "): + with pytest.raises( + AttributeError, + match="can't set attribute 'mode'|property 'mode' of ", + ): par_set = { keys[i]: pars[keys[i]], keys[j]: pars[keys[j]], @@ -1144,6 +1185,44 @@ def test_init_mode(self, mode, expectation): with expectation: self.cls(5, mode=mode, window_size=window_size) + @pytest.mark.parametrize( + "conv_kwargs, expectation", + [ + (dict(), does_not_raise()), + ( + dict(axis=0), + pytest.raises( + ValueError, match="Setting the `axis` parameter is not allowed" + ), + ), + ( + dict(axis=1), + pytest.raises( + ValueError, match="Setting the `axis` parameter is not allowed" + ), + ), + (dict(shift=True), does_not_raise()), + ( + dict(shift=True, axis=0), + pytest.raises( + ValueError, match="Setting the `axis` parameter is not allowed" + ), + ), + ( + dict(shifts=True), + pytest.raises(ValueError, match="Unrecognized keyword arguments"), + ), + (dict(shift=True, predictor_causality="causal"), does_not_raise()), + ( + dict(shift=True, time_series=np.arange(10)), + pytest.raises(ValueError, match="Unrecognized keyword arguments"), + ), + ], + ) + def test_init_conv_kwargs(self, conv_kwargs, expectation): + with expectation: + self.cls(5, mode="conv", window_size=200, **conv_kwargs) + @pytest.mark.parametrize( "mode, ws, expectation", [ @@ -1202,7 +1281,10 @@ def test_set_params( for i in range(len(pars)): for j in range(i + 1, len(pars)): - with pytest.raises(AttributeError, match="can't set attribute 'mode'|property 'mode' of "): + with pytest.raises( + AttributeError, + match="can't set attribute 'mode'|property 'mode' of ", + ): par_set = { keys[i]: pars[keys[i]], keys[j]: pars[keys[j]], @@ -1778,6 +1860,44 @@ def test_init_mode(self, mode, expectation): with expectation: self.cls(5, mode=mode, window_size=window_size) + @pytest.mark.parametrize( + "conv_kwargs, expectation", + [ + (dict(), does_not_raise()), + ( + dict(axis=0), + pytest.raises( + ValueError, match="Setting the `axis` parameter is not allowed" + ), + ), + ( + dict(axis=1), + pytest.raises( + ValueError, match="Setting the `axis` parameter is not allowed" + ), + ), + (dict(shift=True), does_not_raise()), + ( + dict(shift=True, axis=0), + pytest.raises( + ValueError, match="Setting the `axis` parameter is not allowed" + ), + ), + ( + dict(shifts=True), + pytest.raises(ValueError, match="Unrecognized keyword arguments"), + ), + (dict(shift=True, predictor_causality="causal"), does_not_raise()), + ( + dict(shift=True, time_series=np.arange(10)), + pytest.raises(ValueError, match="Unrecognized keyword arguments"), + ), + ], + ) + def test_init_conv_kwargs(self, conv_kwargs, expectation): + with expectation: + self.cls(5, mode="conv", window_size=200, **conv_kwargs) + @pytest.mark.parametrize( "mode, ws, expectation", [ @@ -1836,7 +1956,10 @@ def test_set_params( for i in range(len(pars)): for j in range(i + 1, len(pars)): - with pytest.raises(AttributeError, match="can't set attribute 'mode'|property 'mode' of "): + with pytest.raises( + AttributeError, + match="can't set attribute 'mode'|property 'mode' of ", + ): par_set = { keys[i]: pars[keys[i]], keys[j]: pars[keys[j]], @@ -2490,6 +2613,50 @@ def test_init_mode(self, mode, expectation): with expectation: self.cls(5, mode=mode, window_size=window_size, decay_rates=np.arange(1, 6)) + @pytest.mark.parametrize( + "conv_kwargs, expectation", + [ + (dict(), does_not_raise()), + ( + dict(axis=0), + pytest.raises( + ValueError, match="Setting the `axis` parameter is not allowed" + ), + ), + ( + dict(axis=1), + pytest.raises( + ValueError, match="Setting the `axis` parameter is not allowed" + ), + ), + (dict(shift=True), does_not_raise()), + ( + dict(shift=True, axis=0), + pytest.raises( + ValueError, match="Setting the `axis` parameter is not allowed" + ), + ), + ( + dict(shifts=True), + pytest.raises(ValueError, match="Unrecognized keyword arguments"), + ), + (dict(shift=True, predictor_causality="causal"), does_not_raise()), + ( + dict(shift=True, time_series=np.arange(10)), + pytest.raises(ValueError, match="Unrecognized keyword arguments"), + ), + ], + ) + def test_init_conv_kwargs(self, conv_kwargs, expectation): + with expectation: + self.cls( + 5, + mode="conv", + window_size=200, + decay_rates=np.arange(1, 6), + **conv_kwargs, + ) + @pytest.mark.parametrize( "mode, ws, expectation", [ @@ -2557,7 +2724,10 @@ def test_set_params( for i in range(len(pars)): for j in range(i + 1, len(pars)): - with pytest.raises(AttributeError, match="can't set attribute 'mode'|property 'mode' of "): + with pytest.raises( + AttributeError, + match="can't set attribute 'mode'|property 'mode' of ", + ): par_set = { keys[i]: pars[keys[i]], keys[j]: pars[keys[j]], @@ -3078,6 +3248,44 @@ def test_init_mode(self, mode, expectation): with expectation: self.cls(5, mode=mode, window_size=window_size) + @pytest.mark.parametrize( + "conv_kwargs, expectation", + [ + (dict(), does_not_raise()), + ( + dict(axis=0), + pytest.raises( + ValueError, match="Setting the `axis` parameter is not allowed" + ), + ), + ( + dict(axis=1), + pytest.raises( + ValueError, match="Setting the `axis` parameter is not allowed" + ), + ), + (dict(shift=True), does_not_raise()), + ( + dict(shift=True, axis=0), + pytest.raises( + ValueError, match="Setting the `axis` parameter is not allowed" + ), + ), + ( + dict(shifts=True), + pytest.raises(ValueError, match="Unrecognized keyword arguments"), + ), + (dict(shift=True, predictor_causality="causal"), does_not_raise()), + ( + dict(shift=True, time_series=np.arange(10)), + pytest.raises(ValueError, match="Unrecognized keyword arguments"), + ), + ], + ) + def test_init_conv_kwargs(self, conv_kwargs, expectation): + with expectation: + self.cls(5, mode="conv", window_size=200, **conv_kwargs) + @pytest.mark.parametrize( "mode, ws, expectation", [ @@ -3136,7 +3344,10 @@ def test_set_params( for i in range(len(pars)): for j in range(i + 1, len(pars)): - with pytest.raises(AttributeError, match="can't set attribute 'mode'|property 'mode' of "): + with pytest.raises( + AttributeError, + match="can't set attribute 'mode'|property 'mode' of ", + ): par_set = { keys[i]: pars[keys[i]], keys[j]: pars[keys[j]], @@ -3792,7 +4003,10 @@ def test_set_params( for i in range(len(pars)): for j in range(i + 1, len(pars)): - with pytest.raises(AttributeError, match="can't set attribute 'mode'|property 'mode' of "): + with pytest.raises( + AttributeError, + match="can't set attribute 'mode'|property 'mode' of ", + ): par_set = { keys[i]: pars[keys[i]], keys[j]: pars[keys[j]],