diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 2394576e..0e967ed2 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -18,7 +18,7 @@ on:
       - main

 jobs:
-  tox:
+  tox_tests:
     if: ${{ !github.event.pull_request.draft }}
     strategy:
       matrix:
@@ -38,7 +38,6 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install .
           pip install tox

       - name: Run tox
@@ -49,6 +48,27 @@ jobs:
         env:
           CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

+  tox_check:
+    if: ${{ !github.event.pull_request.draft }}
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v4 # Use v4 for compatibility with pyproject.toml
+        with:
+          python-version: 3.12
+          cache: pip
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install tox
+
+      - name: Run tox
+        run: tox -e check
+
   prevent_docs_absolute_links:
     runs-on: ubuntu-latest
     steps:
@@ -63,7 +83,10 @@ jobs:

   check:
     if: ${{ !github.event.pull_request.draft }}
-    needs: tox
+    needs:
+      - tox_tests
+      - prevent_docs_absolute_links
+      - tox_check
     runs-on: ubuntu-latest
     steps:
       - name: Decide whether all tests and notebooks succeeded
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 5cbb99cb..6963b567 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -96,6 +96,19 @@ See below for details on how to [add tests](#adding-tests) and properly [document your code](#adding-documentation).

 Lastly, you should make sure that the existing tests all run successfully and that the codebase is formatted properly:

+> [!TIP]
+> The [NeMoS GitHub action](.github/workflows/ci.yml) runs tests and some additional style checks in an isolated environment using [`tox`](https://tox.wiki/en/). `tox` is not included in our optional dependencies, so if you want to replicate the action workflow locally, you need to install `tox` via pip and then run it. From the package directory:
+> ```sh
+> pip install tox
+> tox -e check,py
+> ```
+> This will execute `tox` with a Python version that matches your local environment. If the above passes, then the GitHub action will pass and your PR is mergeable.
+>
+> You can also use `tox` to run `black` and `isort` and try to fix your code automatically if either of those checks is failing. To do so, run `tox -e fix`.
+>
+> `tox` configurations can be found in the [`tox.ini`](tox.ini) file.
+
+
 ```bash
 # run tests and make sure they all pass
 pytest tests/
@@ -105,7 +118,10 @@ pytest --doctest-modules src/nemos/

 # format the code base
 black src/
-isort src
+isort src --profile=black
+isort docs/how_to_guide --profile=black
+isort docs/background --profile=black
+isort docs/tutorials --profile=black
 flake8 --config=tox.ini src
 ```
@@ -129,14 +145,6 @@ changes. Additionally, every PR to `main` or `development` will automatically
 run linters and tests through a [GitHub action](https://docs.github.com/en/actions).
 Merges can happen only when all checks pass.

-> [!NOTE]
-> The [NeMoS GitHub action](.github/workflows/ci.yml) runs tests in an isolated environment using [`tox`](https://tox.wiki/en/). `tox` is not included in our optional dependencies, so if you want to replicate the action workflow locally, you need to install `tox` via pip and then run it. From the package directory:
-> ```sh
-> pip install tox
-> tox -e py
-> ```
-> This will execute `tox` with a Python version that matches your local environment. `tox` configurations can be found in the [`tox.ini`](tox.ini) file.
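The `set_params` override and the reworked `regularizer_strength` setter added to `src/nemos/base_regressor.py` below change when warnings fire as these two parameters are updated. The following is a minimal sketch of the intended behavior, not part of this diff; it assumes the string shortcuts `"Ridge"`, `"Lasso"`, and `"UnRegularized"` are accepted for `regularizer` (as in the docstrings elsewhere in this PR) and that `set_params` routes values through the property setters:

```python
import warnings

import nemos as nmo

model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1)

# Updating both parameters in one call should not warn spuriously while the
# new strength is momentarily paired with the old regularizer: the override
# applies `regularizer` first, with the transient warnings filtered out.
with warnings.catch_warnings():
    warnings.simplefilter("error")  # escalate any stray warning in this sketch
    model.set_params(regularizer="Lasso", regularizer_strength=1.0)

# Supplying a strength for an unregularized model, by contrast, still warns
# that the parameter is unused.
with warnings.catch_warnings(record=True) as recorded:
    warnings.simplefilter("always")
    model.set_params(regularizer="UnRegularized", regularizer_strength=1.0)
assert any("Unused parameter" in str(w.message) for w in recorded)
```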
-
 Once your changes are integrated, you will be added as a GitHub contributor and as one of the authors of the package. Thank you for being part of `nemos`!

 ### Style Guide
diff --git a/pyproject.toml b/pyproject.toml
index 5ee288b2..1ca97e44 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,6 +45,7 @@ dev = [
     "isort",        # Import sorter
     "pip-tools",    # Dependency management
     "pytest",       # Testing framework
+    "pytest-xdist", # Parallelize pytest
     "flake8",       # Code linter
     "coverage",     # Test coverage measurement
     "pytest-cov",   # Test coverage plugin for pytest
@@ -112,6 +113,7 @@ profile = "black"
 # Configure pytest
 [tool.pytest.ini_options]
 testpaths = ["tests"]  # Specify the directory where test files are located
+addopts = "-n auto"

 [tool.coverage.run]
 omit = [
diff --git a/src/nemos/base_class.py b/src/nemos/base_class.py
index 67b63240..ba4a015a 100644
--- a/src/nemos/base_class.py
+++ b/src/nemos/base_class.py
@@ -25,7 +25,7 @@ class Base:
     Additionally, it has methods for selecting target devices and sending arrays to them.
     """

-    def get_params(self, deep=True):
+    def get_params(self, deep=True) -> dict:
        """
        From scikit-learn, get parameters by inspecting init.
diff --git a/src/nemos/base_regressor.py b/src/nemos/base_regressor.py
index 7d2cba7f..ba2782ed 100644
--- a/src/nemos/base_regressor.py
+++ b/src/nemos/base_regressor.py
@@ -140,6 +140,23 @@ def solver_run(self) -> Union[None, SolverRun]:
         """
         return self._solver_run

+    def set_params(self, **params: Any):
+        """Set parameters, handling warnings when several related parameters are set together."""
+        # if both regularizer and regularizer_strength are set, then only
+        # warn in case the strength is not expected for the regularizer type
+        if "regularizer" in params and "regularizer_strength" in params:
+            reg = params.pop("regularizer")
+            with warnings.catch_warnings():
+                warnings.filterwarnings(
+                    "ignore",
+                    category=UserWarning,
+                    message="Caution: regularizer strength.*"
+                    "|Unused parameter `regularizer_strength`.*",
+                )
+                super().set_params(regularizer=reg)
+
+        return super().set_params(**params)
+
     @property
     def regularizer(self) -> Union[None, Regularizer]:
         """Getter for the regularizer attribute."""
@@ -171,19 +188,16 @@ def regularizer_strength(self) -> float:

     @regularizer_strength.setter
     def regularizer_strength(self, strength: Union[float, None]):
-        # if using unregularized, strength will be None no matter what
-        if isinstance(self._regularizer, UnRegularized):
-            self._regularizer_strength = None
         # check regularizer strength
-        elif strength is None:
+        if strength is None and not isinstance(self._regularizer, UnRegularized):
             warnings.warn(
                 UserWarning(
                     "Caution: regularizer strength has not been set. Defaulting to 1.0. Please see "
                     "the documentation for best practices in setting regularization strength."
                 )
             )
-            self._regularizer_strength = 1.0
-        else:
+            strength = 1.0
+        elif strength is not None:
             try:
                 # force conversion to float to prevent weird GPU issues
                 strength = float(strength)
@@ -192,7 +206,16 @@ def regularizer_strength(self, strength: Union[float, None]):
                 raise ValueError(
                     f"Could not convert the regularizer strength: {strength} to a float."
                 )
-        self._regularizer_strength = strength
+            if isinstance(self._regularizer, UnRegularized):
+                warnings.warn(
+                    UserWarning(
+                        "Unused parameter `regularizer_strength` for UnRegularized GLM. "
+                        "The regularizer strength parameter is not required and won't be used when the regularizer "
+                        "is set to UnRegularized."
+                    )
+                )
+
+        self._regularizer_strength = strength

     @property
     def solver_name(self) -> str:
diff --git a/src/nemos/basis.py b/src/nemos/basis.py
index ecb69645..51127a31 100644
--- a/src/nemos/basis.py
+++ b/src/nemos/basis.py
@@ -357,7 +357,7 @@ def __sklearn_clone__(self) -> TransformerBasis:
         return cloned_obj

     def set_params(self, **parameters) -> TransformerBasis:
-        r"""
+        """
         Set TransformerBasis parameters.

         When used with `sklearn.model_selection`, users can set either the `_basis` attribute directly
@@ -497,37 +497,40 @@ def __init__(
     ) -> None:
         self.n_basis_funcs = n_basis_funcs
         self._n_input_dimensionality = 0
-        self._check_n_basis_min()
         self._conv_kwargs = kwargs
-        self.bounds = bounds

         # check mode
         if mode not in ["conv", "eval"]:
             raise ValueError(
                 f"`mode` should be either 'conv' or 'eval'. '{mode}' provided instead!"
             )

-        if mode == "conv":
-            if window_size is None:
-                raise ValueError(
-                    "If the basis is in `conv` mode, you must provide a window_size!"
-                )
-            elif not (isinstance(window_size, int) and window_size > 0):
-                raise ValueError(
-                    f"`window_size` must be a positive integer. {window_size} provided instead!"
-                )
-            if bounds is not None:
-                raise ValueError("`bounds` should only be set when `mode=='eval'`.")
-        else:
-            if kwargs:
-                raise ValueError(
-                    f"kwargs should only be set when mode=='conv', but '{mode}' provided instead!"
-                )
-        self._window_size = window_size
         self._mode = mode
+        self.window_size = window_size
+        self.bounds = bounds
+
+        if mode == "eval" and kwargs:
+            raise ValueError(
+                f"kwargs should only be set when mode=='conv', but '{mode}' provided instead!"
+            )
+
         self.kernel_ = None
         self._identifiability_constraints = False

+    @property
+    def n_basis_funcs(self):
+        return self._n_basis_funcs
+
+    @n_basis_funcs.setter
+    def n_basis_funcs(self, value):
+        orig_n_basis = copy.deepcopy(getattr(self, "_n_basis_funcs", None))
+        self._n_basis_funcs = value
+        try:
+            self._check_n_basis_min()
+        except ValueError as e:
+            self._n_basis_funcs = orig_n_basis
+            raise e
+
     @property
     def bounds(self):
         return self._bounds
@@ -535,16 +538,26 @@ def bounds(self):
     @bounds.setter
     def bounds(self, values: Union[None, Tuple[float, float]]):
         """Setter for bounds."""
+
+        if values is not None and self.mode == "conv":
+            raise ValueError("`bounds` should only be set when `mode=='eval'`.")
+
         if values is not None and len(values) != 2:
             raise ValueError(
                 f"The provided `bounds` must be of length two. Length {len(values)} provided instead!"
             )
+
         # convert to float and store
         try:
             self._bounds = values if values is None else tuple(map(float, values))
         except (ValueError, TypeError):
             raise TypeError("Could not convert `bounds` to float.")

+        if values is not None and values[1] <= values[0]:
+            raise ValueError(
+                f"Invalid bound {values}. Lower bound is greater than or equal to the upper bound."
+            )
+
     @property
     def mode(self):
         return self._mode
@@ -553,6 +566,28 @@ def mode(self):
     def window_size(self):
         return self._window_size

+    @window_size.setter
+    def window_size(self, window_size):
+        """Setter for the window size parameter."""
+        if self.mode == "eval":
+            if window_size:
+                raise ValueError(
+                    "If basis is in `mode=='eval'`, `window_size` should be None."
+                )
+
+        else:
+            if window_size is None:
+                raise ValueError(
+                    "If the basis is in `conv` mode, you must provide a window_size!"
+                )
+
+            elif not (isinstance(window_size, int) and window_size > 0):
+                raise ValueError(
+                    f"`window_size` must be a positive integer. {window_size} provided instead!"
+                )
+
+        self._window_size = window_size
+
     @property
     def identifiability_constraints(self):
         return self._identifiability_constraints
@@ -1016,7 +1051,7 @@ def to_transformer(self) -> TransformerBasis:
         >>> # load some data
         >>> X, y = np.random.normal(size=(30, 1)), np.random.poisson(size=30)
         >>> basis = nmo.basis.RaisedCosineBasisLinear(10).to_transformer()
-        >>> glm = nmo.glm.GLM(regularizer="Ridge")
+        >>> glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=1.)
         >>> pipeline = Pipeline([("basis", basis), ("glm", glm)])
         >>> param_grid = dict(
         ...     glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6),
@@ -1296,10 +1331,34 @@ def __init__(
             bounds=bounds,
             **kwargs,
         )
+        self._n_input_dimensionality = 1

-        if self.order < 1:
+
+    @property
+    def order(self):
+        return self._order
+
+    @order.setter
+    def order(self, value):
+        """Setter for the order parameter."""
+
+        if value < 1:
             raise ValueError("Spline order must be positive!")
+        # `_order` is not yet set the first time the setter runs; fall back to None.
+        orig_order = copy.deepcopy(getattr(self, "_order", None))
+
+        # Set the order
+        self._order = value
+
+        # If the order was already initialized, re-check the basis validity
+        if orig_order is not None:
+            try:
+                self._check_n_basis_min()
+            except ValueError as e:
+                self._order = orig_order
+                raise e
+
     def _generate_knots(
         self,
         sample_pts: NDArray,
@@ -2041,19 +2100,23 @@ def __init__(
         # The samples are scaled appropriately in the self._transform_samples which scales
         # and applies the log-stretch, no additional transform is needed.
         self._rescale_samples = False
+        if time_scaling is None:
+            time_scaling = 50.0
+
+        self.time_scaling = time_scaling
         self.enforce_decay_to_zero = enforce_decay_to_zero
-        if time_scaling is None:
-            self._time_scaling = 50.0
-        else:
-            self._check_time_scaling(time_scaling)
-            self._time_scaling = time_scaling

     @property
     def time_scaling(self):
         """Getter property for time_scaling."""
         return self._time_scaling

+    @time_scaling.setter
+    def time_scaling(self, time_scaling):
+        """Setter property for time_scaling."""
+        self._check_time_scaling(time_scaling)
+        self._time_scaling = time_scaling
+
     @staticmethod
     def _check_time_scaling(time_scaling: float) -> None:
         if time_scaling <= 0:
diff --git a/src/nemos/regularizer.py b/src/nemos/regularizer.py
index 15a2181f..c2cc75d6 100644
--- a/src/nemos/regularizer.py
+++ b/src/nemos/regularizer.py
@@ -355,7 +355,7 @@ class GroupLasso(Regularizer):
     >>> # Create the GroupLasso regularizer instance
     >>> group_lasso = GroupLasso(mask=mask)
     >>> # fit a group-lasso glm
-    >>> model = GLM(regularizer=group_lasso).fit(X, y)
+    >>> model = GLM(regularizer=group_lasso, regularizer_strength=0.1).fit(X, y)
     >>> print(f"coeff shape: {model.coef_.shape}")
     coeff shape: (5,)
     """
diff --git a/tests/conftest.py b/tests/conftest.py
index c945cbd6..8a36fadf 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -398,7 +398,7 @@ def example_data_prox_operator():
         ),
     )
     regularizer_strength = 0.1
-    mask = jnp.array([[1, 0, 1, 0], [0, 1, 0, 1]], dtype=jnp.float32)
+    mask = jnp.array([[1, 0, 1, 0], [0, 1, 0, 1]]).astype(float)
     scaling = 0.5
     return params, regularizer_strength, mask, scaling
diff --git a/tests/test_basis.py b/tests/test_basis.py
index 4e320743..3e21db33 100644
--- a/tests/test_basis.py
+++ b/tests/test_basis.py
@@ -2,15 +2,14 @@
 import inspect
 import pickle
 from contextlib import nullcontext as does_not_raise
+from typing import Literal

 import jax.numpy
 import numpy as np
 import pynapple as nap
 import pytest
-import sklearn.pipeline as pipeline
 import utils_testing
 from
sklearn.base import clone as sk_clone -from sklearn.model_selection import GridSearchCV import nemos.basis as basis import nemos.convolve as convolve @@ -18,10 +17,11 @@ # automatic define user accessible basis and check the methods + def list_all_basis_classes() -> list[type]: """ Return all the classes in nemos.basis which are a subclass of Basis, - which should be all concrete classes except TransformerBasis. + which should be all concrete classes except TransformerBasis. """ return [ class_obj @@ -146,18 +146,32 @@ def test_sample_size_of_compute_features_matches_that_of_input( "samples, vmin, vmax, expectation", [ (0.5, 0, 1, does_not_raise()), - (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), + ( + -0.5, + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - ] + ( + np.linspace(-1, 0, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ( + np.linspace(1, 2, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ], ) def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation): bas = self.cls(5, bounds=(vmin, vmax)) with expectation: bas(samples) - @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) def test_minimum_number_of_basis_required_is_matched( @@ -181,7 +195,9 @@ def test_minimum_number_of_basis_required_is_matched( @pytest.mark.parametrize("n_input", [0, 1, 2, 3]) @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_number_of_required_inputs_compute_features(self, n_input, mode, window_size): + def test_number_of_required_inputs_compute_features( + self, n_input, mode, window_size + ): """ Confirms that the compute_features() method correctly handles the number of input samples that are provided. 
""" @@ -389,8 +405,12 @@ def test_call_nan(self, mode, window_size): "samples, expectation", [ (np.array([0, 1, 2, 3, 4, 5]), does_not_raise()), - (np.array(['a', '1', '2', '3', '4', '5']), pytest.raises(TypeError, match="Input samples must")), - ]) + ( + np.array(["a", "1", "2", "3", "4", "5"]), + pytest.raises(TypeError, match="Input samples must"), + ), + ], + ) def test_call_input_type(self, samples, expectation): bas = self.cls(5) with expectation: @@ -489,12 +509,108 @@ def test_init_mode(self, mode, expectation): 1.5, pytest.raises(ValueError, match="`window_size` must be a positive "), ), + ("eval", None, does_not_raise()), + ( + "eval", + 10, + pytest.raises( + ValueError, + match=r"If basis is in `mode=='eval'`, `window_size` should be None", + ), + ), ], ) def test_init_window_size(self, mode, ws, expectation): with expectation: self.cls(5, mode=mode, window_size=ws) + @pytest.mark.parametrize( + "enforce_decay_to_zero, time_scaling, width, window_size, n_basis_funcs, bounds, mode", + [ + (False, 15, 4, None, 10, (1, 2), "eval"), + (False, 15, 4, 10, 10, None, "conv"), + ], + ) + def test_set_params( + self, + enforce_decay_to_zero, + time_scaling, + width, + window_size, + n_basis_funcs, + bounds, + mode: Literal["eval", "conv"], + ): + """Test the read-only and read/write property of the parameters.""" + pars = dict( + enforce_decay_to_zero=enforce_decay_to_zero, + time_scaling=time_scaling, + width=width, + window_size=window_size, + n_basis_funcs=n_basis_funcs, + bounds=bounds, + ) + keys = list(pars.keys()) + bas = self.cls( + enforce_decay_to_zero=enforce_decay_to_zero, + time_scaling=time_scaling, + width=width, + window_size=window_size, + n_basis_funcs=n_basis_funcs, + mode=mode, + ) + for i in range(len(pars)): + for j in range(i + 1, len(pars)): + par_set = {keys[i]: pars[keys[i]], keys[j]: pars[keys[j]]} + bas = bas.set_params(**par_set) + assert isinstance(bas, self.cls) + + for i in range(len(pars)): + for j in range(i + 1, len(pars)): + with pytest.raises(AttributeError, match="can't set attribute 'mode'|property 'mode' of "): + par_set = { + keys[i]: pars[keys[i]], + keys[j]: pars[keys[j]], + "mode": mode, + } + bas.set_params(**par_set) + + @pytest.mark.parametrize( + "mode, expectation", + [ + ("eval", does_not_raise()), + ("conv", pytest.raises(ValueError, match="`bounds` should only be set")), + ], + ) + def test_set_bounds(self, mode, expectation): + ws = dict(eval=None, conv=10) + with expectation: + self.cls(window_size=ws[mode], n_basis_funcs=10, mode=mode, bounds=(1, 2)) + + bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv", bounds=None) + with pytest.raises(ValueError, match="`bounds` should only be set"): + bas.set_params(bounds=(1, 2)) + + @pytest.mark.parametrize( + "mode, expectation", + [ + ("conv", does_not_raise()), + ("eval", pytest.raises(ValueError, match="If basis is in `mode=='eval'`")), + ], + ) + def test_set_window_size(self, mode, expectation): + """Test window size set behavior.""" + with expectation: + self.cls(window_size=10, n_basis_funcs=10, mode=mode) + + bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv") + with pytest.raises(ValueError, match="If the basis is in `conv` mode"): + bas.set_params(window_size=None) + + bas = self.cls(window_size=None, n_basis_funcs=10, mode="eval") + with pytest.raises(ValueError, match="If basis is in `mode=='eval'`"): + bas.set_params(window_size=10) + def test_convolution_is_performed(self): bas = self.cls(5, mode="conv", window_size=10) x = np.random.normal(size=100) @@ 
-523,24 +639,46 @@ def test_conv_kwargs_error(self): with pytest.raises(ValueError, match="kwargs should only be set"): self.cls(5, mode="eval", test="hi") - @pytest.mark.parametrize( "bounds, expectation", [ (None, does_not_raise()), - ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), - ((1, None), pytest.raises(TypeError, match=r"Could not convert")), + ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), + ((1, None), pytest.raises(TypeError, match=r"Could not convert")), ((1, 3), does_not_raise()), (("a", 3), pytest.raises(TypeError, match="Could not convert")), ((1, "a"), pytest.raises(TypeError, match="Could not convert")), - (("a", "a"), pytest.raises(TypeError, match="Could not convert")) - ] + (("a", "a"), pytest.raises(TypeError, match="Could not convert")), + ], ) def test_vmin_vmax_init(self, bounds, expectation): with expectation: bas = self.cls(3, bounds=bounds) assert bounds == bas.bounds if bounds else bas.bounds is None + @pytest.mark.parametrize( + "bounds, expectation", + [ + (None, does_not_raise()), + ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), + ((1, None), pytest.raises(TypeError, match=r"Could not convert")), + ((1, 3), does_not_raise()), + (("a", 3), pytest.raises(TypeError, match="Could not convert")), + ((1, "a"), pytest.raises(TypeError, match="Could not convert")), + (("a", "a"), pytest.raises(TypeError, match="Could not convert")), + ( + (2, 1), + pytest.raises( + ValueError, match=r"Invalid bound \(2, 1\). Lower bound is greater" + ), + ), + ], + ) + def test_vmin_vmax_setter(self, bounds, expectation): + bas = self.cls(3, bounds=(1, 3)) + with expectation: + bas.set_params(bounds=bounds) + assert bounds == bas.bounds if bounds else bas.bounds is None @pytest.mark.parametrize( "vmin, vmax, samples, nan_idx", @@ -548,8 +686,8 @@ def test_vmin_vmax_init(self, bounds, expectation): (None, None, np.arange(5), []), (0, 3, np.arange(5), [4]), (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]) - ] + (1, 3, np.arange(5), [0, 4]), + ], ) def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): bounds = None if vmin is None else (vmin, vmax) @@ -564,12 +702,14 @@ def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): [ (0, 3, np.arange(5), [4]), (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]) - ] + (1, 3, np.arange(5), [0, 4]), + ], ) - def test_vmin_vmax_eval_on_grid_no_effect_on_eval(self, vmin, vmax, samples, nan_idx): - bas_no_range = self.cls(3, mode="eval", window_size=10, bounds=None) - bas = self.cls(3, mode="eval", window_size=10, bounds=(vmin, vmax)) + def test_vmin_vmax_eval_on_grid_no_effect_on_eval( + self, vmin, vmax, samples, nan_idx + ): + bas_no_range = self.cls(3, mode="eval", bounds=None) + bas = self.cls(3, mode="eval", bounds=(vmin, vmax)) _, out1 = bas.evaluate_on_grid(10) _, out2 = bas_no_range.evaluate_on_grid(10) assert np.allclose(out1, out2) @@ -580,8 +720,8 @@ def test_vmin_vmax_eval_on_grid_no_effect_on_eval(self, vmin, vmax, samples, nan (None, np.arange(5), [4], 0, 1), ((0, 3), np.arange(5), [4], 0, 3), ((1, 4), np.arange(5), [0], 1, 4), - ((1, 3), np.arange(5), [0, 4], 1, 3) - ] + ((1, 3), np.arange(5), [0, 4], 1, 3), + ], ) def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx): bas_no_range = self.cls(3, mode="eval", bounds=None) @@ -596,8 +736,8 @@ def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx (None, np.arange(5), does_not_raise()), ((0, 3), np.arange(5), pytest.raises(ValueError, 
match="`bounds` should")), ((1, 4), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), - ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")) - ] + ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), + ], ) def test_vmin_vmax_mode_conv(self, bounds, samples, exception): with exception: @@ -685,11 +825,26 @@ def test_sample_size_of_compute_features_matches_that_of_input( "samples, vmin, vmax, expectation", [ (0.5, 0, 1, does_not_raise()), - (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), + ( + -0.5, + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - ] + ( + np.linspace(-1, 0, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ( + np.linspace(1, 2, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ], ) def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation): bas = self.cls(5, bounds=(vmin, vmax)) @@ -719,7 +874,9 @@ def test_minimum_number_of_basis_required_is_matched( @pytest.mark.parametrize("n_input", [0, 1, 2, 3]) @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_number_of_required_inputs_compute_features(self, n_input, mode, window_size): + def test_number_of_required_inputs_compute_features( + self, n_input, mode, window_size + ): """ Confirms that the compute_features() method correctly handles the number of input samples that are provided. """ @@ -868,8 +1025,12 @@ def test_call_nan(self, mode, window_size): "samples, expectation", [ (np.array([0, 1, 2, 3, 4, 5]), does_not_raise()), - (np.array(['a', '1', '2', '3', '4', '5']), pytest.raises(TypeError, match="Input samples must")), - ]) + ( + np.array(["a", "1", "2", "3", "4", "5"]), + pytest.raises(TypeError, match="Input samples must"), + ), + ], + ) def test_call_input_type(self, samples, expectation): bas = self.cls(5) with expectation: @@ -885,11 +1046,26 @@ def test_call_equivalent_in_conv(self): "samples, vmin, vmax, expectation", [ (0.5, 0, 1, does_not_raise()), - (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), + ( + -0.5, + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - ] + ( + np.linspace(-1, 0, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ( + np.linspace(1, 2, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ], ) def test_call_vmin_vmax(self, samples, vmin, vmax, expectation): bas = self.cls(5, bounds=(vmin, vmax)) @@ -982,12 +1158,94 @@ def test_init_mode(self, mode, expectation): 1.5, pytest.raises(ValueError, match="`window_size` must be a positive "), ), + ("eval", None, does_not_raise()), + ( + "eval", + 10, + pytest.raises( + ValueError, + match=r"If basis is in `mode=='eval'`, `window_size` should be None", + ), + ), ], ) def test_init_window_size(self, mode, ws, expectation): with expectation: self.cls(5, mode=mode, window_size=ws) 
+ @pytest.mark.parametrize( + "width, window_size, n_basis_funcs, bounds, mode", + [ + (4, None, 10, (1, 2), "eval"), + (4, 10, 10, None, "conv"), + ], + ) + def test_set_params( + self, width, window_size, n_basis_funcs, bounds, mode: Literal["eval", "conv"] + ): + """Test the read-only and read/write property of the parameters.""" + pars = dict( + width=width, + window_size=window_size, + n_basis_funcs=n_basis_funcs, + bounds=bounds, + ) + keys = list(pars.keys()) + bas = self.cls( + width=width, window_size=window_size, n_basis_funcs=n_basis_funcs, mode=mode + ) + for i in range(len(pars)): + for j in range(i + 1, len(pars)): + par_set = {keys[i]: pars[keys[i]], keys[j]: pars[keys[j]]} + bas.set_params(**par_set) + assert isinstance(bas, self.cls) + + for i in range(len(pars)): + for j in range(i + 1, len(pars)): + with pytest.raises(AttributeError, match="can't set attribute 'mode'|property 'mode' of "): + par_set = { + keys[i]: pars[keys[i]], + keys[j]: pars[keys[j]], + "mode": mode, + } + bas.set_params(**par_set) + + @pytest.mark.parametrize( + "mode, expectation", + [ + ("eval", does_not_raise()), + ("conv", pytest.raises(ValueError, match="`bounds` should only be set")), + ], + ) + def test_set_bounds(self, mode, expectation): + ws = dict(eval=None, conv=10) + with expectation: + self.cls(window_size=ws[mode], n_basis_funcs=10, mode=mode, bounds=(1, 2)) + + bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv", bounds=None) + with pytest.raises(ValueError, match="`bounds` should only be set"): + bas.set_params(bounds=(1, 2)) + + @pytest.mark.parametrize( + "mode, expectation", + [ + ("conv", does_not_raise()), + ("eval", pytest.raises(ValueError, match="If basis is in `mode=='eval'`")), + ], + ) + def test_set_window_size(self, mode, expectation): + """Test window size set behavior.""" + with expectation: + self.cls(window_size=10, n_basis_funcs=10, mode=mode) + + bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv") + with pytest.raises(ValueError, match="If the basis is in `conv` mode"): + bas.set_params(window_size=None) + + bas = self.cls(window_size=None, n_basis_funcs=10, mode="eval") + with pytest.raises(ValueError, match="If basis is in `mode=='eval'`"): + bas.set_params(window_size=10) + def test_convolution_is_performed(self): bas = self.cls(5, mode="conv", window_size=10) x = np.random.normal(size=100) @@ -1024,22 +1282,46 @@ def test_conv_kwargs_error(self): ((1, 3), does_not_raise()), (("a", 3), pytest.raises(TypeError, match="Could not convert")), ((1, "a"), pytest.raises(TypeError, match="Could not convert")), - (("a", "a"), pytest.raises(TypeError, match="Could not convert")) - ] + (("a", "a"), pytest.raises(TypeError, match="Could not convert")), + ], ) def test_vmin_vmax_init(self, bounds, expectation): with expectation: bas = self.cls(3, bounds=bounds) assert bounds == bas.bounds if bounds else bas.bounds is None + @pytest.mark.parametrize( + "bounds, expectation", + [ + (None, does_not_raise()), + ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), + ((1, None), pytest.raises(TypeError, match=r"Could not convert")), + ((1, 3), does_not_raise()), + (("a", 3), pytest.raises(TypeError, match="Could not convert")), + ((1, "a"), pytest.raises(TypeError, match="Could not convert")), + (("a", "a"), pytest.raises(TypeError, match="Could not convert")), + ( + (2, 1), + pytest.raises( + ValueError, match=r"Invalid bound \(2, 1\). 
Lower bound is greater" + ), + ), + ], + ) + def test_vmin_vmax_setter(self, bounds, expectation): + bas = self.cls(5, bounds=(1, 3)) + with expectation: + bas.set_params(bounds=bounds) + assert bounds == bas.bounds if bounds else bas.bounds is None + @pytest.mark.parametrize( "vmin, vmax, samples, nan_idx", [ (None, None, np.arange(5), []), (0, 3, np.arange(5), [4]), (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]) - ] + (1, 3, np.arange(5), [0, 4]), + ], ) def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): bounds = None if vmin is None else (vmin, vmax) @@ -1054,10 +1336,12 @@ def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): [ (0, 3, np.arange(5), [4]), (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]) - ] + (1, 3, np.arange(5), [0, 4]), + ], ) - def test_vmin_vmax_eval_on_grid_no_effect_on_eval(self, vmin, vmax, samples, nan_idx): + def test_vmin_vmax_eval_on_grid_no_effect_on_eval( + self, vmin, vmax, samples, nan_idx + ): bas_no_range = self.cls(3, mode="eval", bounds=None) bas = self.cls(3, mode="eval", bounds=(vmin, vmax)) _, out1 = bas.evaluate_on_grid(10) @@ -1070,8 +1354,8 @@ def test_vmin_vmax_eval_on_grid_no_effect_on_eval(self, vmin, vmax, samples, nan (None, np.arange(5), [4], 0, 1), ((0, 3), np.arange(5), [4], 0, 3), ((1, 4), np.arange(5), [0], 1, 4), - ((1, 3), np.arange(5), [0, 4], 1, 3) - ] + ((1, 3), np.arange(5), [0, 4], 1, 3), + ], ) def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx): bas_no_range = self.cls(3, mode="eval", bounds=None) @@ -1086,8 +1370,8 @@ def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx (None, np.arange(5), does_not_raise()), ((0, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), ((1, 4), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), - ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")) - ] + ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), + ], ) def test_vmin_vmax_mode_conv(self, bounds, samples, exception): with exception: @@ -1175,11 +1459,26 @@ def test_sample_size_of_compute_features_matches_that_of_input( "samples, vmin, vmax, expectation", [ (0.5, 0, 1, does_not_raise()), - (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), + ( + -0.5, + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - ] + ( + np.linspace(-1, 0, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ( + np.linspace(1, 2, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ], ) def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation): bas = self.cls(5, bounds=(vmin, vmax)) @@ -1233,7 +1532,9 @@ def test_samples_range_matches_compute_features_requirements( @pytest.mark.parametrize("n_input", [0, 1, 2, 3]) @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_number_of_required_inputs_compute_features(self, n_input, mode, window_size): + def test_number_of_required_inputs_compute_features( + self, n_input, mode, window_size + ): """ Confirms that the compute_features() method correctly handles the number of input samples that are provided. 
""" @@ -1363,8 +1664,12 @@ def test_call_nan(self, mode, window_size): "samples, expectation", [ (np.array([0, 1, 2, 3, 4, 5]), does_not_raise()), - (np.array(['a', '1', '2', '3', '4', '5']), pytest.raises(TypeError, match="Input samples must")), - ]) + ( + np.array(["a", "1", "2", "3", "4", "5"]), + pytest.raises(TypeError, match="Input samples must"), + ), + ], + ) def test_call_input_type(self, samples, expectation): bas = self.cls(5) with expectation: @@ -1380,11 +1685,26 @@ def test_call_equivalent_in_conv(self): "samples, vmin, vmax, expectation", [ (0.5, 0, 1, does_not_raise()), - (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), + ( + -0.5, + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - ] + ( + np.linspace(-1, 0, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ( + np.linspace(1, 2, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ], ) def test_call_vmin_vmax(self, samples, vmin, vmax, expectation): bas = self.cls(5, bounds=(vmin, vmax)) @@ -1472,12 +1792,94 @@ def test_init_mode(self, mode, expectation): 1.5, pytest.raises(ValueError, match="`window_size` must be a positive "), ), + ("eval", None, does_not_raise()), + ( + "eval", + 10, + pytest.raises( + ValueError, + match=r"If basis is in `mode=='eval'`, `window_size` should be None", + ), + ), ], ) def test_init_window_size(self, mode, ws, expectation): with expectation: self.cls(5, mode=mode, window_size=ws) + @pytest.mark.parametrize( + "order, window_size, n_basis_funcs, bounds, mode", + [ + (4, None, 10, (1, 2), "eval"), + (4, 10, 10, None, "conv"), + ], + ) + def test_set_params( + self, order, window_size, n_basis_funcs, bounds, mode: Literal["eval", "conv"] + ): + """Test the read-only and read/write property of the parameters.""" + pars = dict( + order=order, + window_size=window_size, + n_basis_funcs=n_basis_funcs, + bounds=bounds, + ) + keys = list(pars.keys()) + bas = self.cls( + order=order, window_size=window_size, n_basis_funcs=n_basis_funcs, mode=mode + ) + for i in range(len(pars)): + for j in range(i + 1, len(pars)): + par_set = {keys[i]: pars[keys[i]], keys[j]: pars[keys[j]]} + bas.set_params(**par_set) + assert isinstance(bas, self.cls) + + for i in range(len(pars)): + for j in range(i + 1, len(pars)): + with pytest.raises(AttributeError, match="can't set attribute 'mode'|property 'mode' of "): + par_set = { + keys[i]: pars[keys[i]], + keys[j]: pars[keys[j]], + "mode": mode, + } + bas.set_params(**par_set) + + @pytest.mark.parametrize( + "mode, expectation", + [ + ("eval", does_not_raise()), + ("conv", pytest.raises(ValueError, match="`bounds` should only be set")), + ], + ) + def test_set_bounds(self, mode, expectation): + ws = dict(eval=None, conv=10) + with expectation: + self.cls(window_size=ws[mode], n_basis_funcs=10, mode=mode, bounds=(1, 2)) + + bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv", bounds=None) + with pytest.raises(ValueError, match="`bounds` should only be set"): + bas.set_params(bounds=(1, 2)) + + @pytest.mark.parametrize( + "mode, expectation", + [ + ("conv", does_not_raise()), + ("eval", pytest.raises(ValueError, match="If basis is in `mode=='eval'`")), + ], + ) + def 
test_set_window_size(self, mode, expectation): + """Test window size set behavior.""" + with expectation: + self.cls(window_size=10, n_basis_funcs=10, mode=mode) + + bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv") + with pytest.raises(ValueError, match="If the basis is in `conv` mode"): + bas.set_params(window_size=None) + + bas = self.cls(window_size=None, n_basis_funcs=10, mode="eval") + with pytest.raises(ValueError, match="If basis is in `mode=='eval'`"): + bas.set_params(window_size=10) + def test_convolution_is_performed(self): bas = self.cls(5, mode="conv", window_size=10) x = np.random.normal(size=100) @@ -1505,7 +1907,6 @@ def test_conv_kwargs_error(self): with pytest.raises(ValueError, match="kwargs should only be set"): self.cls(5, mode="eval", test="hi") - @pytest.mark.parametrize( "bounds, expectation", [ @@ -1515,22 +1916,52 @@ def test_conv_kwargs_error(self): ((1, 3), does_not_raise()), (("a", 3), pytest.raises(TypeError, match="Could not convert")), ((1, "a"), pytest.raises(TypeError, match="Could not convert")), - (("a", "a"), pytest.raises(TypeError, match="Could not convert")) - ] + (("a", "a"), pytest.raises(TypeError, match="Could not convert")), + ( + (2, 1), + pytest.raises( + ValueError, match=r"Invalid bound \(2, 1\). Lower bound is greater" + ), + ), + ], ) def test_vmin_vmax_init(self, bounds, expectation): with expectation: bas = self.cls(3, bounds=bounds) assert bounds == bas.bounds if bounds else bas.bounds is None + @pytest.mark.parametrize( + "bounds, expectation", + [ + (None, does_not_raise()), + ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), + ((1, None), pytest.raises(TypeError, match=r"Could not convert")), + ((1, 3), does_not_raise()), + (("a", 3), pytest.raises(TypeError, match="Could not convert")), + ((1, "a"), pytest.raises(TypeError, match="Could not convert")), + (("a", "a"), pytest.raises(TypeError, match="Could not convert")), + ( + (2, 1), + pytest.raises( + ValueError, match=r"Invalid bound \(2, 1\). 
Lower bound is greater" + ), + ), + ], + ) + def test_vmin_vmax_setter(self, bounds, expectation): + bas = self.cls(3, bounds=(1, 3)) + with expectation: + bas.set_params(bounds=bounds) + assert bounds == bas.bounds if bounds else bas.bounds is None + @pytest.mark.parametrize( "vmin, vmax, samples, nan_idx", [ (None, None, np.arange(5), []), (0, 3, np.arange(5), [4]), (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]) - ] + (1, 3, np.arange(5), [0, 4]), + ], ) def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): bounds = None if vmin is None else (vmin, vmax) @@ -1545,10 +1976,12 @@ def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): [ (None, np.arange(5), [4], 1), ((1, 4), np.arange(5), [0], 3), - ((1, 3), np.arange(5), [0, 4], 2) - ] + ((1, 3), np.arange(5), [0, 4], 2), + ], ) - def test_vmin_vmax_eval_on_grid_scaling_effect_on_eval(self, bounds, samples, nan_idx, scaling): + def test_vmin_vmax_eval_on_grid_scaling_effect_on_eval( + self, bounds, samples, nan_idx, scaling + ): """Check that the MSpline has the expected scaling property.""" bas_no_range = self.cls(3, mode="eval", bounds=None) bas = self.cls(3, mode="eval", bounds=bounds) @@ -1565,8 +1998,8 @@ def test_vmin_vmax_eval_on_grid_scaling_effect_on_eval(self, bounds, samples, na (None, np.arange(5), [4], 0, 1), ((0, 3), np.arange(5), [4], 0, 3), ((1, 4), np.arange(5), [0], 1, 4), - ((1, 3), np.arange(5), [0, 4], 1, 3) - ] + ((1, 3), np.arange(5), [0, 4], 1, 3), + ], ) def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx): bas_no_range = self.cls(3, mode="eval", bounds=None) @@ -1581,8 +2014,8 @@ def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx (None, np.arange(5), does_not_raise()), ((0, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), ((1, 4), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), - ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")) - ] + ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), + ], ) def test_vmin_vmax_mode_conv(self, bounds, samples, exception): with exception: @@ -1596,6 +2029,7 @@ def test_transformer_get_params(self): params_basis = bas.get_params() assert params_transf == params_basis + class TestOrthExponentialBasis(BasisFuncsTesting): cls = basis.OrthExponentialBasis @@ -1684,13 +2118,38 @@ def test_sample_size_of_compute_features_matches_that_of_input( @pytest.mark.parametrize( "samples, vmin, vmax, expectation", [ - (np.linspace(-0.5, -0.001, 7), 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), - (np.linspace(1.5, 2., 7), 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), - ([-0.5, -0.1, -0.01, 1.5, 2 , 3], 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), + ( + np.linspace(-0.5, -0.001, 7), + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), + ( + np.linspace(1.5, 2.0, 7), + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), + ( + [-0.5, -0.1, -0.01, 1.5, 2, 3], + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - ] + ( + np.linspace(-1, 0, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ( + 
np.linspace(1, 2, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ], ) def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation): bas = self.cls(5, bounds=(vmin, vmax), decay_rates=np.linspace(0.1, 1, 5)) @@ -1727,7 +2186,9 @@ def test_minimum_number_of_basis_required_is_matched( @pytest.mark.parametrize("n_input", [0, 1, 2, 3]) @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)]) - def test_number_of_required_inputs_compute_features(self, n_input, mode, window_size): + def test_number_of_required_inputs_compute_features( + self, n_input, mode, window_size + ): """Tests whether the compute_features method correctly processes the number of required inputs.""" basis_obj = self.cls( n_basis_funcs=5, @@ -1902,8 +2363,12 @@ def test_call_nan(self, mode, window_size): "samples, expectation", [ (np.array([0, 1, 2, 3, 4, 5]), does_not_raise()), - (np.array(['a', '1', '2', '3', '4', '5']), pytest.raises(TypeError, match="Input samples must")), - ]) + ( + np.array(["a", "1", "2", "3", "4", "5"]), + pytest.raises(TypeError, match="Input samples must"), + ), + ], + ) def test_call_input_type(self, samples, expectation): bas = self.cls(5, np.linspace(0.1, 1, 5)) with expectation: @@ -1918,14 +2383,29 @@ def test_call_equivalent_in_conv(self): @pytest.mark.parametrize( "samples, vmin, vmax, expectation", [ - (np.linspace(-1,-0.5, 10), 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), + ( + np.linspace(-1, -0.5, 10), + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - ] + ( + np.linspace(-1, 0, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ( + np.linspace(1, 2, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ], ) def test_call_vmin_vmax(self, samples, vmin, vmax, expectation): - bas = self.cls(5, decay_rates=np.linspace(0,1,5), bounds=(vmin, vmax)) + bas = self.cls(5, decay_rates=np.linspace(0, 1, 5), bounds=(vmin, vmax)) with expectation: bas(samples) @@ -2005,31 +2485,145 @@ def test_transform_fails(self): ), ], ) - def test_init_mode(self, mode, expectation): - window_size = None if mode == "eval" else 10 + def test_init_mode(self, mode, expectation): + window_size = None if mode == "eval" else 10 + with expectation: + self.cls(5, mode=mode, window_size=window_size, decay_rates=np.arange(1, 6)) + + @pytest.mark.parametrize( + "mode, ws, expectation", + [ + ("conv", 2, does_not_raise()), + ("conv", 10, does_not_raise()), + ( + "conv", + -1, + pytest.raises(ValueError, match="`window_size` must be a positive "), + ), + ( + "conv", + 1.5, + pytest.raises(ValueError, match="`window_size` must be a positive "), + ), + ("eval", None, does_not_raise()), + ( + "eval", + 10, + pytest.raises( + ValueError, + match=r"If basis is in `mode=='eval'`, `window_size` should be None", + ), + ), + ], + ) + def test_init_window_size(self, mode, ws, expectation): + with expectation: + self.cls(5, mode=mode, window_size=ws, decay_rates=np.arange(1, 6)) + + @pytest.mark.parametrize( + "decay_rates, window_size, n_basis_funcs, bounds, mode", + [ + (np.arange(1, 11), None, 10, (1, 2), "eval"), + (np.arange(1, 11), 10, 10, None, "conv"), + ], + ) + def test_set_params( 
+ self, + decay_rates, + window_size, + n_basis_funcs, + bounds, + mode: Literal["eval", "conv"], + ): + """Test the read-only and read/write property of the parameters.""" + pars = dict( + decay_rates=decay_rates, + window_size=window_size, + n_basis_funcs=n_basis_funcs, + bounds=bounds, + ) + keys = list(pars.keys()) + bas = self.cls( + decay_rates=decay_rates, + window_size=window_size, + n_basis_funcs=n_basis_funcs, + mode=mode, + ) + for i in range(len(pars)): + for j in range(i + 1, len(pars)): + par_set = {keys[i]: pars[keys[i]], keys[j]: pars[keys[j]]} + bas.set_params(**par_set) + assert isinstance(bas, self.cls) + + for i in range(len(pars)): + for j in range(i + 1, len(pars)): + with pytest.raises(AttributeError, match="can't set attribute 'mode'|property 'mode' of "): + par_set = { + keys[i]: pars[keys[i]], + keys[j]: pars[keys[j]], + "mode": mode, + } + bas.set_params(**par_set) + + @pytest.mark.parametrize( + "mode, expectation", + [ + ("eval", does_not_raise()), + ("conv", pytest.raises(ValueError, match="`bounds` should only be set")), + ], + ) + def test_set_bounds(self, mode, expectation): + ws = dict(eval=None, conv=10) with expectation: - self.cls(5, mode=mode, window_size=window_size, decay_rates=np.arange(1, 6)) + self.cls( + decay_rates=np.arange(1, 11), + window_size=ws[mode], + n_basis_funcs=10, + mode=mode, + bounds=(1, 2), + ) + + bas = self.cls( + decay_rates=np.arange(1, 11), + window_size=10, + n_basis_funcs=10, + mode="conv", + bounds=None, + ) + with pytest.raises(ValueError, match="`bounds` should only be set"): + bas.set_params(bounds=(1, 2)) @pytest.mark.parametrize( - "mode, ws, expectation", + "mode, expectation", [ - ("conv", 2, does_not_raise()), - ("conv", 10, does_not_raise()), - ( - "conv", - -1, - pytest.raises(ValueError, match="`window_size` must be a positive "), - ), - ( - "conv", - 1.5, - pytest.raises(ValueError, match="`window_size` must be a positive "), - ), + ("conv", does_not_raise()), + ("eval", pytest.raises(ValueError, match="If basis is in `mode=='eval'`")), ], ) - def test_init_window_size(self, mode, ws, expectation): + def test_set_window_size(self, mode, expectation): + """Test window size set behavior.""" with expectation: - self.cls(5, mode=mode, window_size=ws, decay_rates=np.arange(1, 6)) + self.cls( + decay_rates=np.arange(1, 11), + window_size=10, + n_basis_funcs=10, + mode=mode, + ) + + bas = self.cls( + decay_rates=np.arange(1, 11), window_size=10, n_basis_funcs=10, mode="conv" + ) + with pytest.raises(ValueError, match="If the basis is in `conv` mode"): + bas.set_params(window_size=None) + + bas = self.cls( + decay_rates=np.arange(1, 11), + window_size=None, + n_basis_funcs=10, + mode="eval", + ) + with pytest.raises(ValueError, match="If basis is in `mode=='eval'`"): + bas.set_params(window_size=10) def test_convolution_is_performed(self): bas = self.cls(5, mode="conv", window_size=10, decay_rates=np.arange(1, 6)) @@ -2069,6 +2663,7 @@ def test_transformer_get_params(self): assert params_transf == params_basis assert np.all(rates_transf == rates_basis) + class TestBSplineBasis(BasisFuncsTesting): cls = basis.BSplineBasis @@ -2143,11 +2738,26 @@ def test_sample_size_of_compute_features_matches_that_of_input( "samples, vmin, vmax, expectation", [ (0.5, 0, 1, does_not_raise()), - (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), + ( + -0.5, + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - (np.linspace(-1, 0, 
10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - ] + ( + np.linspace(-1, 0, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ( + np.linspace(1, 2, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ], ) def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation): bas = self.cls(5, bounds=(vmin, vmax)) @@ -2216,7 +2826,9 @@ def test_samples_range_matches_compute_features_requirements( @pytest.mark.parametrize("n_input", [0, 1, 2, 3]) @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)]) - def test_number_of_required_inputs_compute_features(self, n_input, mode, window_size): + def test_number_of_required_inputs_compute_features( + self, n_input, mode, window_size + ): """ Confirms that the compute_features() method correctly handles the number of input samples that are provided. """ @@ -2350,8 +2962,12 @@ def test_call_nan(self, mode, window_size): "samples, expectation", [ (np.array([0, 1, 2, 3, 4, 5]), does_not_raise()), - (np.array(['a', '1', '2', '3', '4', '5']), pytest.raises(TypeError, match="Input samples must")), - ]) + ( + np.array(["a", "1", "2", "3", "4", "5"]), + pytest.raises(TypeError, match="Input samples must"), + ), + ], + ) def test_call_input_type(self, samples, expectation): bas = self.cls(5) with expectation: @@ -2367,11 +2983,26 @@ def test_call_equivalent_in_conv(self): "samples, vmin, vmax, expectation", [ (0.5, 0, 1, does_not_raise()), - (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), + ( + -0.5, + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - ] + ( + np.linspace(-1, 0, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ( + np.linspace(1, 2, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ], ) def test_call_vmin_vmax(self, samples, vmin, vmax, expectation): bas = self.cls(5, bounds=(vmin, vmax)) @@ -2461,12 +3092,94 @@ def test_init_mode(self, mode, expectation): 1.5, pytest.raises(ValueError, match="`window_size` must be a positive "), ), + ("eval", None, does_not_raise()), + ( + "eval", + 10, + pytest.raises( + ValueError, + match=r"If basis is in `mode=='eval'`, `window_size` should be None", + ), + ), ], ) def test_init_window_size(self, mode, ws, expectation): with expectation: self.cls(5, mode=mode, window_size=ws) + @pytest.mark.parametrize( + "order, window_size, n_basis_funcs, bounds, mode", + [ + (3, None, 10, (1, 2), "eval"), + (3, 10, 10, None, "conv"), + ], + ) + def test_set_params( + self, order, window_size, n_basis_funcs, bounds, mode: Literal["eval", "conv"] + ): + """Test the read-only and read/write property of the parameters.""" + pars = dict( + order=order, + window_size=window_size, + n_basis_funcs=n_basis_funcs, + bounds=bounds, + ) + keys = list(pars.keys()) + bas = self.cls( + order=order, window_size=window_size, n_basis_funcs=n_basis_funcs, mode=mode + ) + for i in range(len(pars)): + for j in range(i + 1, len(pars)): + par_set = {keys[i]: pars[keys[i]], keys[j]: pars[keys[j]]} + bas.set_params(**par_set) + assert 
isinstance(bas, self.cls) + + for i in range(len(pars)): + for j in range(i + 1, len(pars)): + with pytest.raises(AttributeError, match="can't set attribute 'mode'|property 'mode' of "): + par_set = { + keys[i]: pars[keys[i]], + keys[j]: pars[keys[j]], + "mode": mode, + } + bas.set_params(**par_set) + + @pytest.mark.parametrize( + "mode, expectation", + [ + ("eval", does_not_raise()), + ("conv", pytest.raises(ValueError, match="`bounds` should only be set")), + ], + ) + def test_set_bounds(self, mode, expectation): + ws = dict(eval=None, conv=10) + with expectation: + self.cls(window_size=ws[mode], n_basis_funcs=10, mode=mode, bounds=(1, 2)) + + bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv", bounds=None) + with pytest.raises(ValueError, match="`bounds` should only be set"): + bas.set_params(bounds=(1, 2)) + + @pytest.mark.parametrize( + "mode, expectation", + [ + ("conv", does_not_raise()), + ("eval", pytest.raises(ValueError, match="If basis is in `mode=='eval'`")), + ], + ) + def test_set_window_size(self, mode, expectation): + """Test window size set behavior.""" + with expectation: + self.cls(window_size=10, n_basis_funcs=10, mode=mode) + + bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv") + with pytest.raises(ValueError, match="If the basis is in `conv` mode"): + bas.set_params(window_size=None) + + bas = self.cls(window_size=None, n_basis_funcs=10, mode="eval") + with pytest.raises(ValueError, match="If basis is in `mode=='eval'`"): + bas.set_params(window_size=10) + def test_convolution_is_performed(self): bas = self.cls(5, mode="conv", window_size=10) x = np.random.normal(size=100) @@ -2494,7 +3207,6 @@ def test_conv_kwargs_error(self): with pytest.raises(ValueError, match="kwargs should only be set"): self.cls(5, mode="eval", test="hi") - @pytest.mark.parametrize( "bounds, expectation", [ @@ -2504,22 +3216,46 @@ def test_conv_kwargs_error(self): ((1, 3), does_not_raise()), (("a", 3), pytest.raises(TypeError, match="Could not convert")), ((1, "a"), pytest.raises(TypeError, match="Could not convert")), - (("a", "a"), pytest.raises(TypeError, match="Could not convert")) - ] + (("a", "a"), pytest.raises(TypeError, match="Could not convert")), + ], ) def test_vmin_vmax_init(self, bounds, expectation): with expectation: bas = self.cls(5, bounds=bounds) assert bounds == bas.bounds if bounds else bas.bounds is None + @pytest.mark.parametrize( + "bounds, expectation", + [ + (None, does_not_raise()), + ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), + ((1, None), pytest.raises(TypeError, match=r"Could not convert")), + ((1, 3), does_not_raise()), + (("a", 3), pytest.raises(TypeError, match="Could not convert")), + ((1, "a"), pytest.raises(TypeError, match="Could not convert")), + (("a", "a"), pytest.raises(TypeError, match="Could not convert")), + ( + (2, 1), + pytest.raises( + ValueError, match=r"Invalid bound \(2, 1\). 
Lower bound is greater" + ), + ), + ], + ) + def test_vmin_vmax_setter(self, bounds, expectation): + bas = self.cls(5, bounds=(1, 3)) + with expectation: + bas.set_params(bounds=bounds) + assert bounds == bas.bounds if bounds else bas.bounds is None + @pytest.mark.parametrize( "vmin, vmax, samples, nan_idx", [ (None, None, np.arange(5), []), (0, 3, np.arange(5), [4]), (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]) - ] + (1, 3, np.arange(5), [0, 4]), + ], ) def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): bounds = None if vmin is None else (vmin, vmax) @@ -2534,10 +3270,12 @@ def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): [ (0, 3, np.arange(5), [4]), (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]) - ] + (1, 3, np.arange(5), [0, 4]), + ], ) - def test_vmin_vmax_eval_on_grid_no_effect_on_eval(self, vmin, vmax, samples, nan_idx): + def test_vmin_vmax_eval_on_grid_no_effect_on_eval( + self, vmin, vmax, samples, nan_idx + ): bas_no_range = self.cls(5, mode="eval", bounds=None) bas = self.cls(5, mode="eval", bounds=(vmin, vmax)) _, out1 = bas.evaluate_on_grid(10) @@ -2550,8 +3288,8 @@ def test_vmin_vmax_eval_on_grid_no_effect_on_eval(self, vmin, vmax, samples, nan (None, np.arange(5), [4], 0, 1), ((0, 3), np.arange(5), [4], 0, 3), ((1, 4), np.arange(5), [0], 1, 4), - ((1, 3), np.arange(5), [0, 4], 1, 3) - ] + ((1, 3), np.arange(5), [0, 4], 1, 3), + ], ) def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx): bas_no_range = self.cls(5, mode="eval", bounds=None) @@ -2566,8 +3304,8 @@ def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx (None, np.arange(5), does_not_raise()), ((0, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), ((1, 4), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), - ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")) - ] + ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), + ], ) def test_vmin_vmax_mode_conv(self, bounds, samples, exception): with exception: @@ -2656,11 +3394,26 @@ def test_sample_size_of_compute_features_matches_that_of_input( "samples, vmin, vmax, expectation", [ (0.5, 0, 1, does_not_raise()), - (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), + ( + -0.5, + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - ] + ( + np.linspace(-1, 0, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ( + np.linspace(1, 2, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ], ) def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation): bas = self.cls(5, bounds=(vmin, vmax)) @@ -2747,7 +3500,9 @@ def test_samples_range_matches_compute_features_requirements( @pytest.mark.parametrize("n_input", [0, 1, 2, 3]) @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)]) - def test_number_of_required_inputs_compute_features(self, n_input, mode, window_size): + def test_number_of_required_inputs_compute_features( + self, n_input, mode, window_size + ): """ Confirms that the compute_features() method correctly handles the number of input samples that are provided. 
""" @@ -2881,8 +3636,12 @@ def test_call_nan(self, mode, window_size): "samples, expectation", [ (np.array([0, 1, 2, 3, 4, 5]), does_not_raise()), - (np.array(['a', '1', '2', '3', '4', '5']), pytest.raises(TypeError, match="Input samples must")), - ]) + ( + np.array(["a", "1", "2", "3", "4", "5"]), + pytest.raises(TypeError, match="Input samples must"), + ), + ], + ) def test_call_input_type(self, samples, expectation): bas = self.cls(5) with expectation: @@ -2898,26 +3657,26 @@ def test_call_equivalent_in_conv(self): "samples, vmin, vmax, expectation", [ (0.5, 0, 1, does_not_raise()), - (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), + ( + -0.5, + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - ] - ) - def test_call_vmin_vmax(self, samples, vmin, vmax, expectation): - bas = self.cls(5, bounds=(vmin, vmax)) - with expectation: - bas(samples) - - @pytest.mark.parametrize( - "samples, vmin, vmax, expectation", - [ - (0.5, 0, 1, does_not_raise()), - (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), - (np.linspace(-1,1,10), 0, 1, does_not_raise()), - (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - ] + ( + np.linspace(-1, 0, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ( + np.linspace(1, 2, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ], ) def test_call_vmin_vmax(self, samples, vmin, vmax, expectation): bas = self.cls(5, bounds=(vmin, vmax)) @@ -2989,12 +3748,94 @@ def test_transform_fails(self): 1.5, pytest.raises(ValueError, match="`window_size` must be a positive "), ), + ("eval", None, does_not_raise()), + ( + "eval", + 10, + pytest.raises( + ValueError, + match=r"If basis is in `mode=='eval'`, `window_size` should be None", + ), + ), ], ) def test_init_window_size(self, mode, ws, expectation): with expectation: self.cls(5, mode=mode, window_size=ws) + @pytest.mark.parametrize( + "order, window_size, n_basis_funcs, bounds, mode", + [ + (3, None, 10, (1, 2), "eval"), + (3, 10, 10, None, "conv"), + ], + ) + def test_set_params( + self, order, window_size, n_basis_funcs, bounds, mode: Literal["eval", "conv"] + ): + """Test the read-only and read/write property of the parameters.""" + pars = dict( + order=order, + window_size=window_size, + n_basis_funcs=n_basis_funcs, + bounds=bounds, + ) + keys = list(pars.keys()) + bas = self.cls( + order=order, window_size=window_size, n_basis_funcs=n_basis_funcs, mode=mode + ) + for i in range(len(pars)): + for j in range(i + 1, len(pars)): + par_set = {keys[i]: pars[keys[i]], keys[j]: pars[keys[j]]} + bas.set_params(**par_set) + assert isinstance(bas, self.cls) + + for i in range(len(pars)): + for j in range(i + 1, len(pars)): + with pytest.raises(AttributeError, match="can't set attribute 'mode'|property 'mode' of "): + par_set = { + keys[i]: pars[keys[i]], + keys[j]: pars[keys[j]], + "mode": mode, + } + bas.set_params(**par_set) + + @pytest.mark.parametrize( + "mode, expectation", + [ + ("eval", does_not_raise()), + ("conv", pytest.raises(ValueError, match="`bounds` should only be set")), 
+ ], + ) + def test_set_bounds(self, mode, expectation): + ws = dict(eval=None, conv=10) + with expectation: + self.cls(window_size=ws[mode], n_basis_funcs=10, mode=mode, bounds=(1, 2)) + + bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv", bounds=None) + with pytest.raises(ValueError, match="`bounds` should only be set"): + bas.set_params(bounds=(1, 2)) + + @pytest.mark.parametrize( + "mode, expectation", + [ + ("conv", does_not_raise()), + ("eval", pytest.raises(ValueError, match="If basis is in `mode=='eval'`")), + ], + ) + def test_set_window_size(self, mode, expectation): + """Test window size set behavior.""" + with expectation: + self.cls(window_size=10, n_basis_funcs=10, mode=mode) + + bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv") + with pytest.raises(ValueError, match="If the basis is in `conv` mode"): + bas.set_params(window_size=None) + + bas = self.cls(window_size=None, n_basis_funcs=10, mode="eval") + with pytest.raises(ValueError, match="If basis is in `mode=='eval'`"): + bas.set_params(window_size=10) + def test_convolution_is_performed(self): bas = self.cls(5, mode="conv", window_size=10) x = np.random.normal(size=100) @@ -3022,7 +3863,6 @@ def test_conv_kwargs_error(self): with pytest.raises(ValueError, match="kwargs should only be set"): self.cls(5, mode="eval", test="hi") - @pytest.mark.parametrize( "bounds, expectation", [ @@ -3032,22 +3872,46 @@ def test_conv_kwargs_error(self): ((1, 3), does_not_raise()), (("a", 3), pytest.raises(TypeError, match="Could not convert")), ((1, "a"), pytest.raises(TypeError, match="Could not convert")), - (("a", "a"), pytest.raises(TypeError, match="Could not convert")) - ] + (("a", "a"), pytest.raises(TypeError, match="Could not convert")), + ], ) def test_vmin_vmax_init(self, bounds, expectation): with expectation: bas = self.cls(5, bounds=bounds) assert bounds == bas.bounds if bounds else bas.bounds is None + @pytest.mark.parametrize( + "bounds, expectation", + [ + (None, does_not_raise()), + ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), + ((1, None), pytest.raises(TypeError, match=r"Could not convert")), + ((1, 3), does_not_raise()), + (("a", 3), pytest.raises(TypeError, match="Could not convert")), + ((1, "a"), pytest.raises(TypeError, match="Could not convert")), + (("a", "a"), pytest.raises(TypeError, match="Could not convert")), + ( + (2, 1), + pytest.raises( + ValueError, match=r"Invalid bound \(2, 1\). 
Lower bound is greater" + ), + ), + ], + ) + def test_vmin_vmax_setter(self, bounds, expectation): + bas = self.cls(5, bounds=(1, 3)) + with expectation: + bas.set_params(bounds=bounds) + assert bounds == bas.bounds if bounds else bas.bounds is None + @pytest.mark.parametrize( "vmin, vmax, samples, nan_idx", [ (None, None, np.arange(5), []), (0, 3, np.arange(5), [4]), (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]) - ] + (1, 3, np.arange(5), [0, 4]), + ], ) def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): bounds = None if vmin is None else (vmin, vmax) @@ -3062,10 +3926,12 @@ def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): [ (0, 3, np.arange(5), [4]), (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]) - ] + (1, 3, np.arange(5), [0, 4]), + ], ) - def test_vmin_vmax_eval_on_grid_no_effect_on_eval(self, vmin, vmax, samples, nan_idx): + def test_vmin_vmax_eval_on_grid_no_effect_on_eval( + self, vmin, vmax, samples, nan_idx + ): bas_no_range = self.cls(5, mode="eval", bounds=None) bas = self.cls(5, mode="eval", bounds=(vmin, vmax)) _, out1 = bas.evaluate_on_grid(10) @@ -3078,8 +3944,8 @@ def test_vmin_vmax_eval_on_grid_no_effect_on_eval(self, vmin, vmax, samples, nan (None, np.arange(5), [4], 0, 1), ((0, 3), np.arange(5), [4], 0, 3), ((1, 4), np.arange(5), [0], 1, 4), - ((1, 3), np.arange(5), [0, 4], 1, 3) - ] + ((1, 3), np.arange(5), [0, 4], 1, 3), + ], ) def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx): bas_no_range = self.cls(5, mode="eval", bounds=None) @@ -3094,8 +3960,8 @@ def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx (None, np.arange(5), does_not_raise()), ((0, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), ((1, 4), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), - ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")) - ] + ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), + ], ) def test_vmin_vmax_mode_conv(self, bounds, samples, exception): with exception: @@ -3109,6 +3975,7 @@ def test_transformer_get_params(self): params_basis = bas.get_params() assert params_transf == params_basis + class CombinedBasis(BasisFuncsTesting): """ This class is used to run tests on combination operations (e.g., addition, multiplication) among Basis functions. @@ -3122,6 +3989,10 @@ class CombinedBasis(BasisFuncsTesting): @staticmethod def instantiate_basis(n_basis, basis_class, mode="eval", window_size=10): """Instantiate and return two basis of the type specified.""" + + if mode == "eval": + window_size = None + if basis_class == basis.MSplineBasis: basis_obj = basis_class( n_basis_funcs=n_basis, order=4, mode=mode, window_size=window_size @@ -3244,7 +4115,7 @@ def test_sample_size_of_compute_features_matches_that_of_input( self, n_basis_a, n_basis_b, sample_size, basis_a, basis_b, mode, window_size ): """ - Test whether the output sample size from the `AdditiveBasis` compute_features function matches the input sample size. + Test whether the output sample size from `AdditiveBasis` compute_features function matches input sample size. """ basis_a_obj = self.instantiate_basis( n_basis_a, basis_a, mode=mode, window_size=window_size @@ -3258,7 +4129,8 @@ def test_sample_size_of_compute_features_matches_that_of_input( ) if eval_basis.shape[0] != sample_size: raise ValueError( - f"Dimensions do not agree: The window size should match the second dimension of the output features basis." 
+ f"Dimensions do not agree: The window size should match the second dimension of the " + f"output features basis." f"The window size is {sample_size}", f"The second dimension of the output features basis is {eval_basis.shape[0]}", ) @@ -3574,7 +4446,11 @@ def test_call_non_empty( @pytest.mark.parametrize( "mn, mx, expectation", - [(0, 1, does_not_raise()), (-2, 2, does_not_raise()), (0.1, 2, does_not_raise())], + [ + (0, 1, does_not_raise()), + (-2, 2, does_not_raise()), + (0.1, 2, does_not_raise()), + ], ) @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @@ -4095,7 +4971,11 @@ def test_call_non_empty( @pytest.mark.parametrize( "mn, mx, expectation", - [(0, 1, does_not_raise()), (-2, 2, does_not_raise()), (0.1, 2, does_not_raise())], + [ + (0, 1, does_not_raise()), + (-2, 2, does_not_raise()), + (0.1, 2, does_not_raise()), + ], ) @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @@ -4243,6 +5123,7 @@ def test_basis_to_transformer(basis_cls): for k in bas.__dict__.keys(): assert getattr(bas, k) == getattr(trans_bas, k) + @pytest.mark.parametrize( "basis_cls", [ @@ -4289,7 +5170,11 @@ def test_to_transformer_and_constructor_are_equivalent(basis_cls): trans_bas_b = basis.TransformerBasis(bas) # they both just have a _basis - assert list(trans_bas_a.__dict__.keys()) == list(trans_bas_b.__dict__.keys()) == ["_basis"] + assert ( + list(trans_bas_a.__dict__.keys()) + == list(trans_bas_b.__dict__.keys()) + == ["_basis"] + ) # and those bases are the same assert trans_bas_a._basis.__dict__ == trans_bas_b._basis.__dict__ @@ -4349,7 +5234,7 @@ def test_transformerbasis_getattr(basis_cls, n_basis_funcs): @pytest.mark.parametrize("n_basis_funcs_new", [6, 10, 20]) def test_transformerbasis_set_params(basis_cls, n_basis_funcs_init, n_basis_funcs_new): trans_basis = basis.TransformerBasis(basis_cls(n_basis_funcs_init)) - trans_basis.set_params(n_basis_funcs = n_basis_funcs_new) + trans_basis.set_params(n_basis_funcs=n_basis_funcs_new) assert trans_basis.n_basis_funcs == n_basis_funcs_new assert trans_basis._basis.n_basis_funcs == n_basis_funcs_new @@ -4374,6 +5259,7 @@ def test_transformerbasis_setattr_basis(basis_cls): assert trans_bas._basis.n_basis_funcs == 20 assert isinstance(trans_bas._basis, basis_cls) + @pytest.mark.parametrize( "basis_cls", [ @@ -4394,6 +5280,7 @@ def test_transformerbasis_setattr_basis_attribute(basis_cls): assert trans_bas._basis.n_basis_funcs == 20 assert isinstance(trans_bas._basis, basis_cls) + @pytest.mark.parametrize( "basis_cls", [ @@ -4415,7 +5302,7 @@ def test_transformerbasis_copy_basis_on_contsruct(basis_cls): assert trans_bas._basis.n_basis_funcs == 20 assert trans_bas._basis.n_basis_funcs == 20 assert isinstance(trans_bas._basis, basis_cls) - + @pytest.mark.parametrize( "basis_cls", @@ -4432,7 +5319,10 @@ def test_transformerbasis_setattr_illegal_attribute(basis_cls): # is not allowed trans_bas = basis.TransformerBasis(basis_cls(10)) - with pytest.raises(ValueError, match="Only setting _basis or existing attributes of _basis is allowed."): + with pytest.raises( + ValueError, + match="Only setting _basis or existing attributes of _basis is allowed.", + ): trans_bas.random_attr = "random value" @@ -4454,11 +5344,18 @@ def test_transformerbasis_addition(basis_cls): trans_bas_sum = trans_bas_a + trans_bas_b assert isinstance(trans_bas_sum, basis.TransformerBasis) assert 
isinstance(trans_bas_sum._basis, basis.AdditiveBasis) - assert trans_bas_sum.n_basis_funcs == trans_bas_a.n_basis_funcs + trans_bas_b.n_basis_funcs - assert trans_bas_sum._n_input_dimensionality == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality + assert ( + trans_bas_sum.n_basis_funcs + == trans_bas_a.n_basis_funcs + trans_bas_b.n_basis_funcs + ) + assert ( + trans_bas_sum._n_input_dimensionality + == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality + ) assert trans_bas_sum._basis1.n_basis_funcs == n_basis_funcs_a assert trans_bas_sum._basis2.n_basis_funcs == n_basis_funcs_b + @pytest.mark.parametrize( "basis_cls", [ @@ -4477,11 +5374,18 @@ def test_transformerbasis_multiplication(basis_cls): trans_bas_prod = trans_bas_a * trans_bas_b assert isinstance(trans_bas_prod, basis.TransformerBasis) assert isinstance(trans_bas_prod._basis, basis.MultiplicativeBasis) - assert trans_bas_prod.n_basis_funcs == trans_bas_a.n_basis_funcs * trans_bas_b.n_basis_funcs - assert trans_bas_prod._n_input_dimensionality == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality + assert ( + trans_bas_prod.n_basis_funcs + == trans_bas_a.n_basis_funcs * trans_bas_b.n_basis_funcs + ) + assert ( + trans_bas_prod._n_input_dimensionality + == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality + ) assert trans_bas_prod._basis1.n_basis_funcs == n_basis_funcs_a assert trans_bas_prod._basis2.n_basis_funcs == n_basis_funcs_b + @pytest.mark.parametrize( "basis_cls", [ @@ -4493,23 +5397,26 @@ def test_transformerbasis_multiplication(basis_cls): ], ) @pytest.mark.parametrize( - "exponent, error_type, error_message", - [ - (2, does_not_raise, None), - (5, does_not_raise, None), - (0.5, TypeError, "Exponent should be an integer"), - (-1, ValueError, "Exponent should be a non-negative integer") - ] + "exponent, error_type, error_message", + [ + (2, does_not_raise, None), + (5, does_not_raise, None), + (0.5, TypeError, "Exponent should be an integer"), + (-1, ValueError, "Exponent should be a non-negative integer"), + ], ) -def test_transformerbasis_exponentiation(basis_cls, exponent: int, error_type, error_message): +def test_transformerbasis_exponentiation( + basis_cls, exponent: int, error_type, error_message +): trans_bas = basis.TransformerBasis(basis_cls(5)) if not isinstance(exponent, int): with pytest.raises(error_type, match=error_message): - trans_bas_exp = trans_bas ** exponent + trans_bas_exp = trans_bas**exponent assert isinstance(trans_bas_exp, basis.TransformerBasis) assert isinstance(trans_bas_exp._basis, basis.MultiplicativeBasis) + @pytest.mark.parametrize( "basis_cls", [ @@ -4522,11 +5429,17 @@ def test_transformerbasis_exponentiation(basis_cls, exponent: int, error_type, e ) def test_transformerbasis_dir(basis_cls): trans_bas = basis.TransformerBasis(basis_cls(5)) - for attr_name in ("fit", "transform", "fit_transform", "n_basis_funcs", "mode", "window_size"): + for attr_name in ( + "fit", + "transform", + "fit_transform", + "n_basis_funcs", + "mode", + "window_size", + ): assert attr_name in dir(trans_bas) - @pytest.mark.parametrize( "basis_cls", [ @@ -4602,7 +5515,7 @@ def test_transformerbasis_pickle(tmpdir, basis_cls, n_basis_funcs): (2, False, "anti-causal", [20, 75]), (2, None, "anti-causal", [20, 19, 75, 74]), (3, False, "acausal", [0, 20, 50, 75]), - (5, False, "acausal", [ 0, 1, 19, 20, 50, 51, 74, 75]), + (5, False, "acausal", [0, 1, 19, 20, 50, 51, 74, 75]), ], ) @pytest.mark.parametrize( @@ -4690,7 +5603,7 
@@ def test_multi_epoch_pynapple_basis( (2, False, "anti-causal", [20, 75]), (2, None, "anti-causal", [20, 19, 75, 74]), (3, False, "acausal", [0, 20, 50, 75]), - (5, False, "acausal", [0, 1, 19, 20, 50, 51, 74, 75]), + (5, False, "acausal", [0, 1, 19, 20, 50, 51, 74, 75]), ], ) @pytest.mark.parametrize( diff --git a/tests/test_convergence.py b/tests/test_convergence.py index 881ee0ac..5a8814e1 100644 --- a/tests/test_convergence.py +++ b/tests/test_convergence.py @@ -76,11 +76,12 @@ def test_ridge_convergence(solver_names): y = np.random.poisson(rate) # instantiate and fit ridge GLM with GradientDescent - model_GD = nmo.glm.GLM(regularizer="Ridge", solver_kwargs=dict(tol=10**-12)) + model_GD = nmo.glm.GLM(regularizer_strength=1., regularizer="Ridge", solver_kwargs=dict(tol=10**-12)) model_GD.fit(X, y) # instantiate and fit ridge GLM with ProximalGradient model_PG = nmo.glm.GLM( + regularizer_strength=1., regularizer="Ridge", solver_name="ProximalGradient", solver_kwargs=dict(tol=10**-12), @@ -108,6 +109,7 @@ def test_lasso_convergence(solver_name): # instantiate and fit GLM with ProximalGradient model_PG = nmo.glm.GLM( regularizer="Lasso", + regularizer_strength=1., solver_name="ProximalGradient", solver_kwargs=dict(tol=10**-12), ) diff --git a/tests/test_convolution.py b/tests/test_convolution.py index a23f8bb0..4c5b75ec 100644 --- a/tests/test_convolution.py +++ b/tests/test_convolution.py @@ -329,7 +329,7 @@ def test_tree_structure_match(self, trial_counts, axis): conv = convolve.create_convolutional_predictor( basis_matrix, trial_counts, axis=axis ) - assert jax.tree_util.tree_structure(trial_counts) == jax.tree_structure(conv) + assert jax.tree_util.tree_structure(trial_counts) == jax.tree_util.tree_structure(conv) @pytest.mark.parametrize("axis", [0, 1, 2]) @pytest.mark.parametrize( @@ -346,7 +346,6 @@ def test_tree_structure_match(self, trial_counts, axis): (2, False, "anti-causal", [29]), (2, None, "anti-causal", [29, 28]), (3, False, "acausal", [29, 0]), - (2, False, "acausal", [29]), ], ) def test_expected_nan(self, axis, window_size, shift, predictor_causality, nan_idx): @@ -394,7 +393,6 @@ def test_expected_nan(self, axis, window_size, shift, predictor_causality, nan_i (2, False, "anti-causal", [20, 75]), (2, None, "anti-causal", [20, 19, 75, 74]), (3, False, "acausal", [0, 20, 50, 75]), - (2, False, "acausal", [20, 75]), ], ) def test_multi_epoch_pynapple( diff --git a/tests/test_glm.py b/tests/test_glm.py index 5c979a94..d17ae3a9 100644 --- a/tests/test_glm.py +++ b/tests/test_glm.py @@ -73,7 +73,7 @@ def test_solver_type(self, regularizer, solver_name, expectation, glm_class): Test that an error is raised if a non-compatible solver is passed. """ with expectation: - glm_class(regularizer=regularizer, solver_name=solver_name) + glm_class(regularizer=regularizer, solver_name=solver_name, regularizer_strength=1) @pytest.mark.parametrize( "observation, expectation", @@ -166,7 +166,7 @@ def test_get_params(self): assert list(model.get_params().values()) == expected_values # changing regularizer - model.regularizer = "Ridge" + model.set_params(regularizer="Ridge", regularizer_strength=1.) 
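The `set_params` hunks above pin down the new coordinated handling of `regularizer` and `regularizer_strength`. A minimal sketch of the behavior these tests assert, assuming only the public `nmo.glm.GLM` API that appears in this diff:

```python
import warnings

import nemos as nmo

model = nmo.glm.GLM()  # defaults to UnRegularized, strength None

# Passing both parameters in one call updates them together, so no
# "regularizer strength has not been set" warning is emitted.
model.set_params(regularizer="Ridge", regularizer_strength=1.0)

# Passing only the regularizer warns and defaults the strength to 1.0.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model.set_params(regularizer="Lasso")
assert model.regularizer_strength == 1.0
```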
expected_values = [ model.observation_model.inverse_link_function, @@ -491,6 +491,7 @@ def test_fit_mask_grouplasso(self, group_sparse_poisson_glm_model_instantiation) """Test that the group lasso fit goes through""" X, y, model, params, rate, mask = group_sparse_poisson_glm_model_instantiation model.set_params( + regularizer_strength=1., regularizer=nmo.regularizer.GroupLasso(mask=mask), solver_name="ProximalGradient", ) @@ -1176,6 +1177,7 @@ def test_initialize_solver_mask_grouplasso( model.set_params( regularizer=nmo.regularizer.GroupLasso(mask=mask), solver_name="ProximalGradient", + regularizer_strength=1., ) params = model.initialize_params(X, y) model.initialize_state(X, y, params) @@ -1481,10 +1483,11 @@ def test_glm_update_consistent_with_fit_with_svrg(self, request, regr_setup, glm n_features = sum(x.shape[1] for x in jax.tree.leaves(X)) regularizer_kwargs["mask"] = (np.random.randn(n_features) > 0).reshape(1, -1).astype(float) + reg = regularizer_class(**regularizer_kwargs) + strength = None if isinstance(reg, nmo.regularizer.UnRegularized) else 1. glm = glm_class( - regularizer=regularizer_class( - **regularizer_kwargs, - ), + regularizer=reg, + regularizer_strength=strength, solver_name=solver_name, solver_kwargs={ "batch_size": batch_size, @@ -1495,9 +1498,7 @@ def test_glm_update_consistent_with_fit_with_svrg(self, request, regr_setup, glm }, ) glm2 = glm_class( - regularizer=regularizer_class( - **regularizer_kwargs, - ), + regularizer=reg, solver_name=solver_name, solver_kwargs={ "batch_size": batch_size, @@ -1506,6 +1507,7 @@ def test_glm_update_consistent_with_fit_with_svrg(self, request, regr_setup, glm "maxiter": maxiter, "key": key, }, + regularizer_strength=strength, ) glm2.fit(X, y) @@ -1623,7 +1625,8 @@ def test_estimate_dof_resid( Test that the dof is an integer. """ X, y, model, true_params, firing_rate = poissonGLM_model_instantiation - model.regularizer = reg + strength = None if isinstance(reg, nmo.regularizer.UnRegularized) else 1. 
+ model.set_params(regularizer=reg, regularizer_strength=strength) model.solver_name = model.regularizer.default_solver model.fit(X, y) num = model._estimate_resid_degrees_of_freedom(X, n_samples=n_samples) @@ -1642,15 +1645,68 @@ def test_warning_solver_reg_str(self, reg): model = nmo.glm.GLM(regularizer=reg, regularizer_strength=1.0) # reset to unregularized - model.regularizer = "UnRegularized" + model.set_params(regularizer = "UnRegularized", regularizer_strength=None) with pytest.warns(UserWarning): nmo.glm.GLM(regularizer=reg) @pytest.mark.parametrize("reg", ["Ridge", "Lasso", "GroupLasso"]) def test_reg_strength_reset(self, reg): model = nmo.glm.GLM(regularizer=reg, regularizer_strength=1.0) - model.regularizer = "UnRegularized" - assert model.regularizer_strength is None + with pytest.warns(UserWarning, match="Unused parameter `regularizer_strength` for UnRegularized GLM"): + model.regularizer = "UnRegularized" + model.regularizer_strength = None + with pytest.warns(UserWarning, match="Caution: regularizer strength has not been set"): + model.regularizer = "Ridge" + + @pytest.mark.parametrize( + "params, warns", + [ + # set regularizer + ({"regularizer": "Ridge"}, + pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "Lasso"}, + pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "GroupLasso"}, + pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "UnRegularized"}, does_not_raise()), + # set both None or number + ({"regularizer": "Ridge", "regularizer_strength": None}, + pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "Ridge", "regularizer_strength": 1.}, does_not_raise()), + ({"regularizer": "Lasso", "regularizer_strength": None}, + pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "Lasso", "regularizer_strength": 1.}, does_not_raise()), + ({"regularizer": "GroupLasso", "regularizer_strength": None}, + pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "GroupLasso", "regularizer_strength": 1.}, does_not_raise()), + ({"regularizer": "UnRegularized", "regularizer_strength": None}, does_not_raise()), + ({"regularizer": "UnRegularized", "regularizer_strength": 1.}, + pytest.warns(UserWarning, match="Unused parameter `regularizer_strength` for UnRegularized GLM")), + # set regularizer str only + ({"regularizer_strength": 1.}, + pytest.warns(UserWarning, match="Unused parameter `regularizer_strength` for UnRegularized GLM")), + ({"regularizer_strength": None}, does_not_raise()), + ] + ) + def test_reg_set_params(self, params, warns): + model = nmo.glm.GLM() + with warns: + model.set_params(**params) + + @pytest.mark.parametrize( + "params, warns", + [ + # set regularizer str only + ({"regularizer_strength": 1.}, does_not_raise()), + ({"regularizer_strength": None}, + pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ] + ) + @pytest.mark.parametrize("reg", ["Ridge", "Lasso", "GroupLasso"]) + def test_reg_set_params_reg_str_only(self, params, warns, reg): + model = nmo.glm.GLM(regularizer=reg, regularizer_strength=1) + with warns: + model.set_params(**params) class TestPopulationGLM: @@ -1684,7 +1740,7 @@ def test_solver_type(self, regularizer, expectation, population_glm_class): Test that an error is raised if a non-compatible solver is passed. 
""" with expectation: - population_glm_class(regularizer=regularizer) + population_glm_class(regularizer=regularizer, regularizer_strength=1.) def test_get_params(self): """ @@ -1732,7 +1788,7 @@ def test_get_params(self): assert list(model.get_params().values()) == expected_values # changing regularizer - model.regularizer = "Ridge" + model.set_params(regularizer="Ridge", regularizer_strength=1.) expected_values = [ model.feature_mask, @@ -1792,7 +1848,7 @@ def test_init_observation_type( """ with expectation: population_glm_class( - regularizer=ridge_regularizer, observation_model=observation + regularizer=ridge_regularizer, observation_model=observation, regularizer_strength=1. ) @pytest.mark.parametrize( @@ -1857,7 +1913,8 @@ def test_estimate_dof_resid( Test that the dof is an integer. """ X, y, model, true_params, firing_rate = poisson_population_GLM_model - model.regularizer = reg + strength = None if isinstance(reg, nmo.regularizer.UnRegularized) else 1. + model.set_params(regularizer=reg, regularizer_strength=strength) model.solver_name = model.regularizer.default_solver model.fit(X, y) num = model._estimate_resid_degrees_of_freedom(X, n_samples=n_samples) @@ -2123,6 +2180,7 @@ def test_fit_mask_grouplasso(self, group_sparse_poisson_glm_model_instantiation) model.set_params( regularizer=nmo.regularizer.GroupLasso(mask=mask), solver_name="ProximalGradient", + regularizer_strength=1., ) model.fit(X, y) @@ -2527,6 +2585,7 @@ def test_initialize_solver_mask_grouplasso( """Test that the group lasso initialize_solver goes through""" X, y, model, params, rate, mask = group_sparse_poisson_glm_model_instantiation model.set_params( + regularizer_strength=1., regularizer=nmo.regularizer.GroupLasso(mask=mask), solver_name="ProximalGradient", ) @@ -3210,13 +3269,13 @@ def test_feature_mask_compatibility_fit_tree( [ ( nmo.regularizer.UnRegularized(), - 0.001, + None, "LBFGS", {"stepsize": 0.1, "tol": 10**-14}, ), ( nmo.regularizer.UnRegularized(), - 1.0, + None, "GradientDescent", {"tol": 10**-14}, ), @@ -3262,7 +3321,7 @@ def test_masked_fit_vs_loop( ): jax.config.update("jax_enable_x64", True) if isinstance(mask, dict): - X, y, model, true_params, firing_rate = poisson_population_GLM_model_pytree + X, y, _, true_params, firing_rate = poisson_population_GLM_model_pytree def map_neu(k, coef_): key_ind = {"input_1": [0, 1, 2], "input_2": [3, 4]} @@ -3275,7 +3334,7 @@ def map_neu(k, coef_): return ind_array, coef_stack else: - X, y, model, true_params, firing_rate = poisson_population_GLM_model + X, y, _, true_params, firing_rate = poisson_population_GLM_model def map_neu(k, coef_): ind_array = np.where(mask[:, k])[0] @@ -3284,11 +3343,14 @@ def map_neu(k, coef_): mask_bool = jax.tree_util.tree_map(lambda x: np.asarray(x.T, dtype=bool), mask) # fit pop glm - model.feature_mask = mask - model.regularizer = regularizer - model.regularizer_strength = regularizer_strength - model.solver_name = solver_name - model.solver_kwargs = solver_kwargs + kwargs = dict( + feature_mask=mask, + regularizer=regularizer, + regularizer_strength=regularizer_strength, + solver_name=solver_name, + solver_kwargs=solver_kwargs, + ) + model = nmo.glm.PopulationGLM(**kwargs) model.fit(X, y) coef_vectorized = np.vstack(jax.tree_util.tree_leaves(model.coef_)) @@ -3331,12 +3393,56 @@ def test_waning_solver_reg_str(self, reg): model = nmo.glm.GLM(regularizer=reg, regularizer_strength=1.0) # reset to unregularized - model.regularizer = "UnRegularized" + model.set_params(regularizer="UnRegularized", regularizer_strength=None) 
with pytest.warns(UserWarning): nmo.glm.GLM(regularizer=reg) @pytest.mark.parametrize("reg", ["Ridge", "Lasso", "GroupLasso"]) def test_reg_strength_reset(self, reg): - model = nmo.glm.GLM(regularizer=reg, regularizer_strength=1.0) - model.regularizer = "UnRegularized" - assert model.regularizer_strength is None + model = nmo.glm.PopulationGLM(regularizer=reg, regularizer_strength=1.0) + with pytest.warns(UserWarning, match="Unused parameter `regularizer_strength` for UnRegularized GLM"): + model.regularizer = "UnRegularized" + model.regularizer_strength = None + with pytest.warns(UserWarning, match="Caution: regularizer strength has not been set"): + model.regularizer = "Ridge" + + @pytest.mark.parametrize( + "params, warns", + [ + # set regularizer + ({"regularizer": "Ridge"}, pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "Lasso"}, pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "GroupLasso"}, pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "UnRegularized"}, does_not_raise()), + # set both None or number + ({"regularizer": "Ridge", "regularizer_strength": None}, pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "Ridge", "regularizer_strength": 1.}, does_not_raise()), + ({"regularizer": "Lasso", "regularizer_strength": None}, pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "Lasso", "regularizer_strength": 1.}, does_not_raise()), + ({"regularizer": "GroupLasso", "regularizer_strength": None}, pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "GroupLasso", "regularizer_strength": 1.}, does_not_raise()), + ({"regularizer": "UnRegularized", "regularizer_strength": None}, does_not_raise()), + ({"regularizer": "UnRegularized", "regularizer_strength": 1.}, pytest.warns(UserWarning, match="Unused parameter `regularizer_strength` for UnRegularized GLM")), + # set regularizer str only + ({"regularizer_strength": 1.}, pytest.warns(UserWarning, match="Unused parameter `regularizer_strength` for UnRegularized GLM")), + ({"regularizer_strength": None},does_not_raise()), + ] + ) + def test_reg_set_params(self, params, warns): + model = nmo.glm.PopulationGLM() + with warns: + model.set_params(**params) + + @pytest.mark.parametrize( + "params, warns", + [ + # set regularizer str only + ({"regularizer_strength": 1.}, does_not_raise()), + ({"regularizer_strength": None},pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ] + ) + @pytest.mark.parametrize("reg", ["Ridge", "Lasso", "GroupLasso"]) + def test_reg_set_params_reg_str_only(self, params, warns, reg): + model = nmo.glm.PopulationGLM(regularizer=reg, regularizer_strength=1) + with warns: + model.set_params(**params) \ No newline at end of file diff --git a/tests/test_observation_models.py b/tests/test_observation_models.py index ffc27c46..25262035 100644 --- a/tests/test_observation_models.py +++ b/tests/test_observation_models.py @@ -1,3 +1,4 @@ +import warnings from contextlib import nullcontext as does_not_raise import jax @@ -501,7 +502,9 @@ def test_pseudo_r2_vs_statsmodels(self, gammaGLM_model_instantiation): X, y, model, _, firing_rate = gammaGLM_model_instantiation # statsmodels mcfadden - mdl = sm.GLM(y, sm.add_constant(X), family=sm.families.Gamma()).fit() + with 
warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="The InversePower link function does") + mdl = sm.GLM(y, sm.add_constant(X), family=sm.families.Gamma()).fit() pr2_sms = mdl.pseudo_rsquared("mcf") # set params diff --git a/tests/test_proximal_operator.py b/tests/test_proximal_operator.py index 59d162bc..a6a65bfb 100644 --- a/tests/test_proximal_operator.py +++ b/tests/test_proximal_operator.py @@ -121,6 +121,7 @@ def test_prox_operator_shrinks_only_masked(example_data_prox_operator): def test_prox_operator_shrinks_only_masked_multineuron(example_data_prox_operator_multineuron): params, _, mask, _ = example_data_prox_operator_multineuron + mask = mask.astype(float) mask = mask.at[:, 1].set(jnp.zeros(2)) params_new = prox_group_lasso(params, 0.05, mask) assert jnp.all(params_new[0][1] == params[0][1]) diff --git a/tests/test_regularizer.py b/tests/test_regularizer.py index 32565d07..5abba876 100644 --- a/tests/test_regularizer.py +++ b/tests/test_regularizer.py @@ -1,4 +1,5 @@ import copy +import warnings import jax import jax.numpy as jnp @@ -6,6 +7,7 @@ import pytest import statsmodels.api as sm from sklearn.linear_model import GammaRegressor, PoissonRegressor +from statsmodels.tools.sm_exceptions import DomainWarning import nemos as nmo @@ -218,9 +220,9 @@ def test_regularizer_strength_none(self): assert model.regularizer_strength == 1.0 - # assert change back to unregularized is none - model.regularizer = regularizer - assert model.regularizer_strength is None + with pytest.warns(UserWarning): + model.regularizer = regularizer + assert model.regularizer_strength == 1. def test_get_params(self): """Test get_params() returns expected values.""" @@ -275,7 +277,7 @@ def test_run_solver(self, solver_name, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation # set regularizer and solver name - model.regularizer = self.cls() + model.set_params(regularizer=self.cls()) model.solver_name = solver_name model.instantiate_solver() model.solver_run((true_params[0] * 0.0, true_params[1]), X, y) @@ -290,7 +292,7 @@ def test_run_solver_tree(self, solver_name, poissonGLM_model_instantiation_pytre X, y, model, true_params, firing_rate = poissonGLM_model_instantiation_pytree # set regularizer and solver name - model.regularizer = self.cls() + model.set_params(regularizer=self.cls()) model.solver_name = solver_name model.instantiate_solver() model.solver_run( @@ -307,7 +309,7 @@ def test_solver_output_match(self, poissonGLM_model_instantiation, solver_name): # set precision to float64 for accurate matching of the results model.data_type = jnp.float64 # set model params - model.regularizer = self.cls() + model.set_params(regularizer=self.cls()) model.solver_name = solver_name model.solver_kwargs = {"tol": 10**-12} model.instantiate_solver() @@ -338,7 +340,7 @@ def test_solver_match_sklearn(self, poissonGLM_model_instantiation, solver_name) X, y, model, true_params, firing_rate = poissonGLM_model_instantiation # set precision to float64 for accurate matching of the results model.data_type = jnp.float64 - model.regularizer = self.cls() + model.set_params(regularizer=self.cls()) model.solver_name = solver_name model.solver_kwargs = {"tol": 10**-12} model.instantiate_solver() @@ -363,7 +365,7 @@ def test_solver_match_sklearn_gamma( # set precision to float64 for accurate matching of the results model.data_type = jnp.float64 model.observation_model.inverse_link_function = jnp.exp - model.regularizer = self.cls() + 
model.set_params(regularizer=self.cls()) model.solver_name = solver_name model.solver_kwargs = {"tol": 10**-12} model.instantiate_solver() @@ -396,17 +398,18 @@ def test_solver_match_statsmodels_gamma( # set precision to float64 for accurate matching of the results model.data_type = jnp.float64 model.observation_model.inverse_link_function = inv_link_jax - model.regularizer = self.cls() + model.set_params(regularizer=self.cls()) model.solver_name = solver_name model.solver_kwargs = {"tol": 10**-13} model.instantiate_solver() weights_bfgs, intercepts_bfgs = model.solver_run( model._initialize_parameters(X, y), X, y )[0] - - model_sm = sm.GLM( - endog=y, exog=sm.add_constant(X), family=sm.families.Gamma(link=link_sm) - ) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="The InversePower link function does ") + model_sm = sm.GLM( + endog=y, exog=sm.add_constant(X), family=sm.families.Gamma(link=link_sm) + ) res_sm = model_sm.fit(cnvrg_tol=10**-12) @@ -429,7 +432,7 @@ def test_solver_match_statsmodels_gamma( ) def test_solver_combination(self, solver_name, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation - model.regularizer = self.cls() + model.set_params(regularizer=self.cls()) model.solver_name = solver_name model.fit(X, y) @@ -465,9 +468,9 @@ def test_init_solver_name(self, solver_name): with pytest.raises( ValueError, match=f"The solver: {solver_name} is not allowed for " ): - nmo.glm.GLM(regularizer=self.cls(), solver_name=solver_name) + nmo.glm.GLM(regularizer=self.cls(), solver_name=solver_name, regularizer_strength=1.) else: - nmo.glm.GLM(regularizer=self.cls(), solver_name=solver_name) + nmo.glm.GLM(regularizer=self.cls(), solver_name=solver_name, regularizer_strength=1.) @pytest.mark.parametrize( "solver_name", @@ -494,7 +497,7 @@ def test_set_solver_name_allowed(self, solver_name): "ProxSVRG", ] regularizer = self.cls() - model = nmo.glm.GLM(regularizer=regularizer) + model = nmo.glm.GLM(regularizer=regularizer, regularizer_strength=1.) raise_exception = solver_name not in acceptable_solvers if raise_exception: with pytest.raises( @@ -518,12 +521,14 @@ def test_init_solver_kwargs(self, solver_name, solver_kwargs): regularizer=regularizer, solver_name=solver_name, solver_kwargs=solver_kwargs, + regularizer_strength=1. ) else: nmo.glm.GLM( regularizer=regularizer, solver_name=solver_name, solver_kwargs=solver_kwargs, + regularizer_strength=1. ) def test_regularizer_strength_none(self): @@ -535,13 +540,13 @@ def test_regularizer_strength_none(self): assert model.regularizer_strength == 1.0 - # if changed to regularized, should go to None - model.regularizer = "UnRegularized" - assert model.regularizer_strength is None + with pytest.warns(UserWarning): + # if changed to regularized, is kept to 1. + model.regularizer = "UnRegularized" + assert model.regularizer_strength == 1.0 # if changed back, should warn and set to 1.0 - with pytest.warns(UserWarning): - model.regularizer = "Ridge" + model.regularizer = "Ridge" assert model.regularizer_strength == 1.0 @@ -556,7 +561,7 @@ def test_loss_is_callable(self, loss): """Test Ridge callable loss.""" raise_exception = not callable(loss) regularizer = self.cls() - model = nmo.glm.GLM(regularizer=regularizer) + model = nmo.glm.GLM(regularizer=regularizer, regularizer_strength=1.) 
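Several hunks in this file wrap statsmodels fits in `warnings.catch_warnings()` so that only the known Gamma-family `DomainWarning` is muted, and only locally rather than process-wide. The pattern in isolation (random data for illustration):

```python
import warnings

import numpy as np
import statsmodels.api as sm

y = np.random.gamma(2.0, 1.0, size=50)
X = np.random.normal(size=(50, 3))

with warnings.catch_warnings():
    # Suppress only the known InversePower-link warning; anything else
    # still surfaces, and the filter is undone when the block exits.
    warnings.filterwarnings("ignore", message="The InversePower link function does")
    mdl = sm.GLM(y, sm.add_constant(X), family=sm.families.Gamma()).fit()
```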
model._predict_and_compute_loss = loss if raise_exception: with pytest.raises(TypeError, match="The `loss` must be a Callable"): @@ -574,7 +579,7 @@ def test_run_solver(self, solver_name, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation # set regularizer and solver name - model.regularizer = self.cls() + model.set_params(regularizer=self.cls(), regularizer_strength=1.) model.solver_name = solver_name runner = model.instantiate_solver().solver_run runner((true_params[0] * 0.0, true_params[1]), X, y) @@ -589,7 +594,7 @@ def test_run_solver_tree(self, solver_name, poissonGLM_model_instantiation_pytre X, y, model, true_params, firing_rate = poissonGLM_model_instantiation_pytree # set regularizer and solver name - model.regularizer = self.cls() + model.set_params(regularizer=self.cls(), regularizer_strength=1.) model.solver_name = solver_name runner = model.instantiate_solver().solver_run runner( @@ -607,7 +612,7 @@ def test_solver_output_match(self, poissonGLM_model_instantiation, solver_name): model.data_type = jnp.float64 # set model params - model.regularizer = self.cls() + model.set_params(regularizer=self.cls(), regularizer_strength=1.) model.solver_name = solver_name model.solver_kwargs = {"tol": 10**-12} @@ -638,7 +643,7 @@ def test_solver_match_sklearn(self, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation # set precision to float64 for accurate matching of the results model.data_type = jnp.float64 - model.regularizer = self.cls() + model.set_params(regularizer=self.cls(), regularizer_strength=1.) model.solver_kwargs = {"tol": 10**-12} model.solver_name = "BFGS" @@ -665,7 +670,7 @@ def test_solver_match_sklearn_gamma(self, gammaGLM_model_instantiation): # set precision to float64 for accurate matching of the results model.data_type = jnp.float64 model.observation_model.inverse_link_function = jnp.exp - model.regularizer = self.cls() + model.set_params(regularizer=self.cls(), regularizer_strength=1.) model.solver_kwargs = {"tol": 10**-12} model.regularizer_strength = 0.1 model.solver_name = "BFGS" @@ -697,7 +702,7 @@ def test_solver_match_sklearn_gamma(self, gammaGLM_model_instantiation): ) def test_solver_combination(self, solver_name, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation - model.regularizer = self.cls() + model.set_params(regularizer=self.cls(), regularizer_strength=1.) 
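These regularizer tests drive the solver below the `fit()` API: `instantiate_solver()` wires the regularizer and solver together, and `solver_run` optimizes from explicit initial parameters. A condensed sketch of that flow, using only calls that appear in this diff (random data for illustration; the runner's return layout is assumed to be `(params, state)`, as the tests' `[0]` indexing suggests):

```python
import numpy as np

import nemos as nmo

X = np.random.normal(size=(100, 3))
y = np.random.poisson(lam=1.0, size=100)

model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=1.0)
model.solver_name = "GradientDescent"
model.solver_kwargs = {"tol": 10**-12}

# fit() normally does this internally; the tests call it directly so they
# can start the run from a chosen (coef, intercept) pair.
runner = model.instantiate_solver().solver_run
init_params = model.initialize_params(X, y)
result = runner(init_params, X, y)
coef, intercept = result[0]  # the tests index [0] for the parameter pair
```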
model.solver_name = solver_name model.fit(X, y) @@ -728,9 +733,9 @@ def test_init_solver_name(self, solver_name): with pytest.raises( ValueError, match=f"The solver: {solver_name} is not allowed for " ): - nmo.glm.GLM(regularizer=self.cls(), solver_name=solver_name) + nmo.glm.GLM(regularizer=self.cls(), solver_name=solver_name, regularizer_strength=1) else: - nmo.glm.GLM(regularizer=self.cls(), solver_name=solver_name) + nmo.glm.GLM(regularizer=self.cls(), solver_name=solver_name, regularizer_strength=1) @pytest.mark.parametrize( "solver_name", @@ -751,7 +756,7 @@ def test_set_solver_name_allowed(self, solver_name): "ProxSVRG", ] regularizer = self.cls() - model = nmo.glm.GLM(regularizer=regularizer) + model = nmo.glm.GLM(regularizer=regularizer, regularizer_strength=1) raise_exception = solver_name not in acceptable_solvers if raise_exception: with pytest.raises( @@ -775,25 +780,27 @@ def test_init_solver_kwargs(self, solver_kwargs, solver_name): regularizer=regularizer, solver_name=solver_name, solver_kwargs=solver_kwargs, + regularizer_strength=1. ) else: nmo.glm.GLM( regularizer=regularizer, solver_name=solver_name, solver_kwargs=solver_kwargs, + regularizer_strength=1. ) def test_regularizer_strength_none(self): """Assert regularizer strength handled appropriately.""" # if no strength given, should warn and set to 1.0 + regularizer = self.cls() with pytest.warns(UserWarning): - regularizer = self.cls() model = nmo.glm.GLM(regularizer=regularizer) assert model.regularizer_strength == 1.0 # if changed to regularized, should go to None - model.regularizer = "UnRegularized" + model.set_params(regularizer="UnRegularized", regularizer_strength=None) assert model.regularizer_strength is None # if changed back, should warn and set to 1.0 @@ -813,7 +820,7 @@ def test_loss_callable(self, loss): """Test that the loss function is a callable""" raise_exception = not callable(loss) regularizer = self.cls() - model = nmo.glm.GLM(regularizer=regularizer) + model = nmo.glm.GLM(regularizer=regularizer, regularizer_strength=1) model._predict_and_compute_loss = loss if raise_exception: with pytest.raises(TypeError, match="The `loss` must be a Callable"): @@ -827,7 +834,7 @@ def test_run_solver(self, solver_name, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation - model.regularizer = self.cls() + model.set_params(regularizer=self.cls(), regularizer_strength=1) model.solver_name = solver_name runner = model.instantiate_solver().solver_run runner((true_params[0] * 0.0, true_params[1]), X, y) @@ -839,7 +846,7 @@ def test_run_solver_tree(self, solver_name, poissonGLM_model_instantiation_pytre X, y, model, true_params, firing_rate = poissonGLM_model_instantiation_pytree # set regularizer and solver name - model.regularizer = self.cls() + model.set_params(regularizer=self.cls(), regularizer_strength=1) model.solver_name = solver_name runner = model.instantiate_solver().solver_run runner( @@ -857,7 +864,7 @@ def test_solver_match_statsmodels( X, y, model, true_params, firing_rate = poissonGLM_model_instantiation # set precision to float64 for accurate matching of the results model.data_type = jnp.float64 - model.regularizer = self.cls() + model.set_params(regularizer=self.cls(), regularizer_strength=1) model.solver_name = solver_name model.solver_kwargs = {"tol": 10**-12} @@ -885,7 +892,7 @@ def test_solver_match_statsmodels( def test_lasso_pytree(self, poissonGLM_model_instantiation_pytree): """Check pytree X can be fit.""" X, y, model, true_params, 
firing_rate = poissonGLM_model_instantiation_pytree - model.regularizer = nmo.regularizer.Lasso() + model.set_params(regularizer=nmo.regularizer.Lasso(), regularizer_strength=1.) model.solver_name = "ProximalGradient" model.fit(X, y) @@ -903,10 +910,9 @@ def test_lasso_pytree_match( X, _, model, _, _ = poissonGLM_model_instantiation_pytree X_array, y, model_array, _, _ = poissonGLM_model_instantiation - model.regularizer_strength = reg_str - model_array.regularizer_strength = reg_str - model.regularizer = nmo.regularizer.Lasso() - model_array.regularizer = nmo.regularizer.Lasso() + + model.set_params(regularizer=nmo.regularizer.Lasso(), regularizer_strength=reg_str) + model_array.set_params(regularizer=nmo.regularizer.Lasso(), regularizer_strength=reg_str) model.solver_name = solver_name model_array.solver_name = solver_name model.fit(X, y) @@ -918,7 +924,7 @@ def test_lasso_pytree_match( @pytest.mark.parametrize("solver_name", ["ProximalGradient", "ProxSVRG"]) def test_solver_combination(self, solver_name, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation - model.regularizer = self.cls() + model.set_params(regularizer=self.cls(), regularizer_strength=1.) model.solver_name = solver_name model.fit(X, y) @@ -956,9 +962,9 @@ def test_init_solver_name(self, solver_name): with pytest.raises( ValueError, match=f"The solver: {solver_name} is not allowed for " ): - nmo.glm.GLM(regularizer=self.cls(mask=mask), solver_name=solver_name) + nmo.glm.GLM(regularizer=self.cls(mask=mask), solver_name=solver_name, regularizer_strength=1) else: - nmo.glm.GLM(regularizer=self.cls(mask=mask), solver_name=solver_name) + nmo.glm.GLM(regularizer=self.cls(mask=mask), solver_name=solver_name, regularizer_strength=1) @pytest.mark.parametrize( "solver_name", @@ -985,7 +991,7 @@ def test_set_solver_name_allowed(self, solver_name): mask = jnp.asarray(mask) regularizer = self.cls(mask=mask) raise_exception = solver_name not in acceptable_solvers - model = nmo.glm.GLM(regularizer=regularizer) + model = nmo.glm.GLM(regularizer=regularizer, regularizer_strength=1) if raise_exception: with pytest.raises( ValueError, match=f"The solver: {solver_name} is not allowed for " @@ -1016,26 +1022,27 @@ def test_init_solver_kwargs(self, solver_name, solver_kwargs): regularizer=regularizer, solver_name=solver_name, solver_kwargs=solver_kwargs, + regularizer_strength=1. ) else: nmo.glm.GLM( regularizer=regularizer, solver_name=solver_name, solver_kwargs=solver_kwargs, + regularizer_strength=1. ) def test_regularizer_strength_none(self): """Assert regularizer strength handled appropriately.""" # if no strength given, should warn and set to 1.0 + regularizer = self.cls() with pytest.warns(UserWarning): - regularizer = self.cls() model = nmo.glm.GLM(regularizer=regularizer) assert model.regularizer_strength == 1.0 # if changed to regularized, should go to None - model.regularizer = "UnRegularized" - assert model.regularizer_strength is None + model.set_params(regularizer="UnRegularized", regularizer_strength=None) # if changed back, should warn and set to 1.0 with pytest.warns(UserWarning): @@ -1061,7 +1068,7 @@ def test_loss_callable(self, loss): mask = jnp.asarray(mask) regularizer = self.cls(mask=mask) - model = nmo.glm.GLM(regularizer=regularizer) + model = nmo.glm.GLM(regularizer=regularizer, regularizer_strength=1.) 
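The GroupLasso tests above and below all build masks the same way. A sketch of a valid mask and fit, assuming the conventions these tests encode: one row per group, one column per feature, 0/1 entries, and each feature in exactly one group:

```python
import jax.numpy as jnp
import numpy as np

import nemos as nmo

n_features = 5
X = np.random.normal(size=(100, n_features))
y = np.random.poisson(lam=1.0, size=100)

# Two groups: features 0-1 in group 0, features 2-4 in group 1.
mask = np.zeros((2, n_features))
mask[0, :2] = 1
mask[1, 2:] = 1
mask = jnp.asarray(mask, dtype=jnp.float32)

model = nmo.glm.GLM(
    regularizer=nmo.regularizer.GroupLasso(mask=mask),
    regularizer_strength=1.0,
    solver_name="ProximalGradient",
)
model.fit(X, y)
```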
model._predict_and_compute_loss = loss if raise_exception: @@ -1082,7 +1089,7 @@ def test_run_solver(self, solver_name, poissonGLM_model_instantiation): mask[1, 2:] = 1 mask = jnp.asarray(mask) - model.regularizer = self.cls(mask=mask) + model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.) model.solver_name = solver_name model.instantiate_solver() @@ -1100,7 +1107,7 @@ def test_init_solver(self, solver_name, poissonGLM_model_instantiation): mask[1, 2:] = 1 mask = jnp.asarray(mask) - model.regularizer = self.cls(mask=mask) + model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.) model.solver_name = solver_name model.instantiate_solver() @@ -1126,7 +1133,7 @@ def test_update_solver(self, solver_name, poissonGLM_model_instantiation): mask[1, 2:] = 1 mask = jnp.asarray(mask) - model.regularizer = self.cls(mask=mask) + model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.) model.solver_name = solver_name model.instantiate_solver() @@ -1186,9 +1193,9 @@ def test_mask_validity_groups( with pytest.raises( ValueError, match="Incorrect group assignment. " "Some of the features" ): - model.regularizer = self.cls(mask=mask) + model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.) else: - model.regularizer = self.cls(mask=mask) + model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.) @pytest.mark.parametrize("set_entry", [0, 1, -1, 2, 2.5]) def test_mask_validity_entries(self, set_entry, poissonGLM_model_instantiation): @@ -1206,9 +1213,9 @@ def test_mask_validity_entries(self, set_entry, poissonGLM_model_instantiation): if raise_exception: with pytest.raises(ValueError, match="Mask elements be 0s and 1s"): - model.regularizer = self.cls(mask=mask) + model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.) else: - model.regularizer = self.cls(mask=mask) + model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.) @pytest.mark.parametrize("n_dim", [0, 1, 2, 3]) def test_mask_dimension_1(self, n_dim, poissonGLM_model_instantiation): @@ -1235,9 +1242,9 @@ def test_mask_dimension_1(self, n_dim, poissonGLM_model_instantiation): if raise_exception: with pytest.raises(ValueError, match="`mask` must be 2-dimensional"): - model.regularizer = self.cls(mask=mask) + model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.) else: - model.regularizer = self.cls(mask=mask) + model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.) @pytest.mark.parametrize("n_groups", [0, 1, 2]) def test_mask_n_groups(self, n_groups, poissonGLM_model_instantiation): @@ -1256,9 +1263,9 @@ def test_mask_n_groups(self, n_groups, poissonGLM_model_instantiation): if raise_exception: with pytest.raises(ValueError, match=r"Empty mask provided! Mask has "): - model.regularizer = self.cls(mask=mask) + model.set_params(regularizer = self.cls(mask=mask), regularizer_strength=1.) else: - model.regularizer = self.cls(mask=mask) + model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.) def test_group_sparsity_enforcement( self, group_sparse_poisson_glm_model_instantiation @@ -1278,7 +1285,7 @@ def test_group_sparsity_enforcement( mask[1, ~zeros_true] = 1 mask = jnp.asarray(mask, dtype=jnp.float32) - model.regularizer = self.cls(mask=mask) + model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.) 
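The mask-validation tests that follow enumerate the failure modes. A hypothetical sketch of the inputs they reject and the errors they expect (the diff does not show whether the check fires in `GroupLasso` itself or in the model's regularizer setter, so the whole construction is wrapped):

```python
import jax.numpy as jnp
import numpy as np

import nemos as nmo

n_features = 5

# Feature assigned to no group -> "Incorrect group assignment".
bad_groups = jnp.asarray(np.zeros((2, n_features)))
# Entries other than 0/1 -> "Mask elements be 0s and 1s".
bad_entries = jnp.asarray(np.full((2, n_features), 0.5))
# Wrong dimensionality -> "`mask` must be 2-dimensional".
bad_ndim = jnp.ones(n_features)

for mask in (bad_groups, bad_entries, bad_ndim):
    try:
        nmo.glm.GLM(
            regularizer=nmo.regularizer.GroupLasso(mask=mask),
            regularizer_strength=1.0,
        )
    except ValueError as err:
        print(err)
```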
model.solver_name = "ProximalGradient" runner = model.instantiate_solver().solver_run @@ -1415,14 +1422,15 @@ def test_mask_none(self, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation with pytest.warns(UserWarning): - model.regularizer = self.cls() + model.regularizer = self.cls(mask=np.ones((1, X.shape[1])).astype(float)) model.solver_name = "ProximalGradient" model.fit(X, y) @pytest.mark.parametrize("solver_name", ["ProximalGradient", "ProxSVRG"]) def test_solver_combination(self, solver_name, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation - model.regularizer = self.cls() + model.set_params(regularizer=self.cls(mask=np.ones((1, X.shape[1])).astype(float)), + regularizer_strength=None if self.cls==nmo.regularizer.UnRegularized else 1.) model.solver_name = solver_name model.fit(X, y) diff --git a/tests/test_solvers.py b/tests/test_solvers.py index 72970397..745e6cf9 100644 --- a/tests/test_solvers.py +++ b/tests/test_solvers.py @@ -128,7 +128,7 @@ def test_svrg_glm_instantiate_solver(regularizer_name, solver_class, mask): if mask is not None: kwargs["mask"] = mask - glm = nmo.glm.GLM(regularizer=regularizer_name, solver_name=solver_name) + glm = nmo.glm.GLM(regularizer=regularizer_name, solver_name=solver_name, regularizer_strength=None if regularizer_name == "UnRegularized" else 1,) glm.instantiate_solver() solver = inspect.getclosurevars(glm._solver_run).nonlocals["solver"] @@ -161,6 +161,7 @@ def test_svrg_glm_passes_solver_kwargs(regularizer_name, solver_name, mask, glm_ regularizer=regularizer_name, solver_name=solver_name, solver_kwargs=solver_kwargs, + regularizer_strength=None if regularizer_name == "UnRegularized" else 1, **kwargs, ) glm.instantiate_solver() @@ -177,9 +178,9 @@ def test_svrg_glm_passes_solver_kwargs(regularizer_name, solver_name, mask, glm_ ( "GroupLasso", ProxSVRG, - np.array([[0], [0], [1]]), + np.array([[0.], [0.], [1.]]), ), - ("GroupLasso", ProxSVRG, None), + ("GroupLasso", ProxSVRG, np.array([[1.], [0.], [0.]])), ("Ridge", SVRG, None), ("UnRegularized", SVRG, None), ], @@ -196,15 +197,22 @@ def test_svrg_glm_initialize_state( if glm_class == nmo.glm.PopulationGLM: y = np.expand_dims(y, 1) + reg_cls = getattr(nmo.regularizer, regularizer_name) + if regularizer_name == "GroupLasso": + reg = reg_cls(mask=mask) + else: + reg = reg_cls() + # only pass mask if it's not None kwargs = {} if mask is not None and glm_class == nmo.glm.PopulationGLM: kwargs["feature_mask"] = mask glm = glm_class( - regularizer=regularizer_name, + regularizer=reg, solver_name=solver_class.__name__, observation_model=nmo.observation_models.PoissonObservations(jax.nn.softplus), + regularizer_strength=None if regularizer_name == "UnRegularized" else 1, **kwargs, ) @@ -225,7 +233,7 @@ def test_svrg_glm_initialize_state( ( "GroupLasso", ProxSVRG, - np.array([[0], [0], [1]]), + np.array([[0.], [0.], [1.]]), ), ("Ridge", SVRG, None), ("UnRegularized", SVRG, None), @@ -244,13 +252,20 @@ def test_svrg_glm_update( # only pass mask if it's not None kwargs = {} - if mask is not None and glm_class == nmo.glm.PopulationGLM: + if glm_class == nmo.glm.PopulationGLM: kwargs["feature_mask"] = mask + reg_cls = getattr(nmo.regularizer, regularizer_name) + if regularizer_name == "GroupLasso": + reg = reg_cls(mask=mask) + else: + reg = reg_cls() + glm = glm_class( - regularizer=regularizer_name, + regularizer=reg, solver_name=solver_class.__name__, 
observation_model=nmo.observation_models.PoissonObservations(jax.nn.softplus), + regularizer_strength=None if regularizer_name == "UnRegularized" else 1, **kwargs, ) @@ -276,15 +291,15 @@ def test_svrg_glm_update( ( "GroupLasso", "ProxSVRG", - np.array([[0, 1, 0], [0, 0, 1]]).reshape(2, -1).astype(float), + np.array([[0, 1, 0, 1, 1], [1, 0, 1, 0, 0]]).astype(float), ), - ("GroupLasso", "ProxSVRG", None), + ("GroupLasso", "ProxSVRG", np.array([[1, 1, 1, 1, 1]]).astype(float)), ( "GroupLasso", "ProximalGradient", - np.array([[0, 1, 0], [0, 0, 1]]).reshape(2, -1).astype(float), + np.array([[0, 1, 0, 1, 1], [1, 0, 1, 0, 0]]).astype(float), ), - ("GroupLasso", "ProximalGradient", None), + ("GroupLasso", "ProximalGradient", np.array([[1, 1, 1, 1, 1]]).astype(float)), ("Ridge", "SVRG", None), ("UnRegularized", "SVRG", None), ], @@ -314,15 +329,23 @@ def test_svrg_glm_fit( } # only pass mask if it's not None + reg_cls = getattr(nmo.regularizer, regularizer_name) + if regularizer_name == "GroupLasso": + reg = reg_cls(mask=mask) + else: + reg = reg_cls() + kwargs = {} - if mask is not None: - kwargs["feature_mask"] = mask + if glm_class == nmo.glm.PopulationGLM: + kwargs["feature_mask"] = np.ones((X.shape[1], 1)) glm = glm_class( - regularizer=regularizer_name, + regularizer=reg, solver_name=solver_name, observation_model=nmo.observation_models.PoissonObservations(jax.nn.softplus), solver_kwargs=solver_kwargs, + regularizer_strength=None if regularizer_name == "UnRegularized" else 1, + **kwargs ) if isinstance(glm, nmo.glm.PopulationGLM): @@ -339,7 +362,7 @@ def test_svrg_glm_fit( "regularizer_name, solver_class, mask", [ ("Lasso", ProxSVRG, None), - ("GroupLasso", ProxSVRG, np.array([0, 1, 0]).reshape(1, -1).astype(float)), + ("GroupLasso", ProxSVRG, np.array([0, 1, 0]).reshape(-1, 1).astype(float)), ("Ridge", SVRG, None), ("UnRegularized", SVRG, None), ], @@ -356,16 +379,23 @@ def test_svrg_glm_update_needs_full_grad_at_reference_point( y = np.expand_dims(y, 1) # only pass mask if it's not None - kwargs = {} - if mask is not None and glm_class == nmo.glm.PopulationGLM: - kwargs["feature_mask"] = mask - - glm = glm_class( - regularizer=regularizer_name, + reg_cls = getattr(nmo.regularizer, regularizer_name) + if regularizer_name == "GroupLasso": + reg = reg_cls(mask=mask) + else: + reg = reg_cls() + kwargs = dict( + regularizer=reg, solver_name=solver_class.__name__, observation_model=nmo.observation_models.PoissonObservations(jax.nn.softplus), + regularizer_strength=None if regularizer_name == "UnRegularized" else 0.1, ) + if mask is not None and glm_class == nmo.glm.PopulationGLM: + kwargs["feature_mask"] = np.array([0, 1, 0]).reshape(-1, 1).astype(float) + + glm = glm_class(**kwargs) + with pytest.raises( ValueError, match=r"Full gradient at the anchor point \(state\.full_grad_at_reference_point\) has to be set", diff --git a/tests/test_tree_utils.py b/tests/test_tree_utils.py index f4e10431..33c58850 100644 --- a/tests/test_tree_utils.py +++ b/tests/test_tree_utils.py @@ -1,5 +1,5 @@ -import numpy as np import jax.numpy as jnp +import numpy as np import pytest from nemos import tree_utils diff --git a/tests/test_type_casting.py b/tests/test_type_casting.py index bed5b1f0..b7cbbe48 100644 --- a/tests/test_type_casting.py +++ b/tests/test_type_casting.py @@ -370,7 +370,7 @@ def func(*x): ( [ nap.Tsd(t=np.arange(10), d=np.arange(10)), - nap.Tsd(t=np.arange(1), d=np.arange(1)), + nap.Tsd(t=np.arange(1), d=np.arange(1), time_support=nap.IntervalSet(0, 10)), nap.Tsd(t=np.arange(10), d=np.arange(10)), 
             ],
             pytest.raises(
diff --git a/tests/test_utils.py b/tests/test_utils.py
index c3babd71..abe46314 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,3 +1,4 @@
+import warnings
 from contextlib import nullcontext as does_not_raise
 
 import jax
@@ -107,7 +108,10 @@ def test_conv_type(self, iterable, predictor_causality):
             with pytest.raises(ValueError, match="predictor_causality must be one of"):
                 utils.nan_pad(iterable, 3, predictor_causality)
         else:
-            utils.nan_pad(iterable, 3, predictor_causality)
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", category=UserWarning,
+                                        message="With acausal filter, pad_size should probably be even")
+                utils.nan_pad(iterable, 3, predictor_causality)
 
     @pytest.mark.parametrize("iterable", [[np.zeros([2, 4, 5]), np.zeros([2, 4, 6])]])
     @pytest.mark.parametrize("pad_size", [0.1, -1, 0, 1, 2, 3, 5, 6])
@@ -159,7 +163,7 @@ def test_padding_nan_anti_causal(self, pad_size, iterable):
         ), "Size after padding doesn't match expectation. Should be T + window_size - 1."
 
     @pytest.mark.parametrize("iterable", [[np.zeros([2, 5, 4]), np.zeros([2, 6, 4])]])
-    @pytest.mark.parametrize("pad_size", [-1, 0.2, 0, 1, 2, 3, 5, 6])
+    @pytest.mark.parametrize("pad_size", [-1, 0.2, 0, 1, 3, 5])
     def test_padding_nan_acausal(self, pad_size, iterable):
         raise_exception = (not isinstance(pad_size, int)) or (pad_size <= 0)
         if raise_exception:
@@ -170,7 +174,10 @@
         else:
             init_nan, end_nan = pad_size // 2, pad_size - pad_size // 2
 
-            padded = utils.nan_pad(iterable, pad_size, "acausal")
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", category=UserWarning,
+                                        message="With acausal filter, pad_size should probably be even")
+                padded = utils.nan_pad(iterable, pad_size, "acausal")
             for trial in padded:
                 print(trial.shape, pad_size)
             assert all(np.isnan(trial[:init_nan]).all() for trial in padded), (
@@ -252,8 +259,11 @@ def test_nan_pad_conv_dtype(self, dtype, expectation):
         ],
     )
     def test_axis_compatibility(self, pad_size, array, causality, axis, expectation):
-        with expectation:
-            utils.nan_pad(array, pad_size, causality, axis=axis)
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", category=UserWarning,
+                                    message="With acausal filter, pad_size should probably be even")
+            with expectation:
+                utils.nan_pad(array, pad_size, causality, axis=axis)
 
     @pytest.mark.parametrize("causality", ["causal", "acausal", "anti-causal"])
     @pytest.mark.parametrize(
@@ -273,8 +283,11 @@ def test_axis_compatibility(self, pad_size, array, causality, axis, expectation)
     )
     @pytest.mark.parametrize("array", [jnp.zeros((10,)), np.zeros((10, 11))])
     def test_pad_size_type(self, pad_size, array, causality, expectation):
-        with expectation:
-            utils.nan_pad(array, pad_size, causality, axis=0)
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", category=UserWarning,
+                                    message="With acausal filter, pad_size should probably be even")
+            with expectation:
+                utils.nan_pad(array, pad_size, causality, axis=0)
 
     @pytest.mark.parametrize(
         "causality, pad_size, expectation",
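All of the `warnings.catch_warnings()` blocks in `tests/test_utils.py` above suppress the same `UserWarning` that `nan_pad` emits for odd acausal pad sizes. An equivalent, more declarative option (not what this patch does) is pytest's `filterwarnings` mark, which takes an `action:message` filter string; the test name below is hypothetical:

```python
import pytest

# The message part of the filter is a regex matched against the start of
# the warning text, so the prefix below is enough to target this warning.
@pytest.mark.filterwarnings(
    "ignore:With acausal filter, pad_size should probably be even"
)
def test_padding_nan_acausal_quiet():
    ...
```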
diff --git a/tox.ini b/tox.ini
index 277799b6..10f8f786 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,6 +1,6 @@
 [tox]
 isolated_build = True
-envlist = py310, py311, py312
+envlist = py,fix
 
 
 [testenv]
@@ -12,17 +12,26 @@
 extras = dev
 package_cache = .tox/cache
 
 # Run both pytest and coverage since pytest was initialized with the --cov option in the pyproject.toml
-# while black, isort and flake8 are also i
 commands =
-    black --check src
+    pytest --doctest-modules src/nemos/
+    pytest --cov=nemos --cov-config=pyproject.toml --cov-report=xml
+
+[testenv:fix]
+commands=
+    black src
     isort src --profile=black
     isort docs/how_to_guide --profile=black
     isort docs/background --profile=black
     isort docs/tutorials --profile=black
-    flake8 --config={toxinidir}/tox.ini src
-    pytest --doctest-modules src/nemos/
-    pytest --cov=nemos --cov-config=pyproject.toml --cov-report=xml
 
+[testenv:check]
+commands=
+    black --check src
+    isort --check src --profile=black
+    isort --check docs/how_to_guide --profile=black
+    isort --check docs/background --profile=black
+    isort --check docs/tutorials --profile=black
+    flake8 --config={toxinidir}/tox.ini src
 [gh-actions]
 python =