diff --git a/src/nemos/basis.py b/src/nemos/basis.py index ef36a28b..a837f71b 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -6,7 +6,6 @@ import abc from typing import Generator, Tuple -import jax.numpy import numpy as np import scipy.linalg from numpy.typing import ArrayLike, NDArray @@ -52,13 +51,13 @@ def __init__(self, n_basis_funcs: int) -> None: self._check_n_basis_min() @abc.abstractmethod - def _evaluate(self, *xi: NDArray) -> NDArray: + def evaluate(self, *xi: NDArray) -> NDArray: """ - Evaluate the basis set at the given samples x1,...,xn using the subclass-specific "_evaluate" method. + Evaluate the basis set at the given samples x1,...,xn using the subclass-specific "evaluate" method. Parameters ---------- - *xi: (number of samples, ) + *xi: (n_samples,) The input samples xi[0],...,xi[n] . """ pass @@ -82,35 +81,30 @@ def _get_samples(*n_samples: int) -> Generator[NDArray, ...]: """ return (np.linspace(0, 1, n_samples[k]) for k in range(len(n_samples))) - def evaluate(self, *xi: ArrayLike) -> NDArray: - """ - Evaluate the basis set at the given samples x[0],...,x[n] using the subclass-specific "_evaluate" method. + def _check_evaluate_input(self, *xi: ArrayLike) -> Tuple[NDArray]: + """Check evaluate input. Parameters ---------- xi[0],...,xi[n] : The input samples, each with shape (number of samples, ). - Returns - ------- - : - The generated basis functions. - Raises ------ ValueError - If the time point number is inconsistent between inputs. - If the number of inputs doesn't match what the Basis object requires. - At least one of the samples is empty. - """ - # check that the input is array-like - if any( - not isinstance(x, (list, tuple, np.ndarray, jax.numpy.ndarray)) for x in xi - ): - raise TypeError("Input samples must be array-like!") - # convert to numpy.array of floats - xi = tuple(np.asarray(x, dtype=float) for x in xi) + """ + # check that the input is array-like (i.e., whether we can cast it to + # numeric arrays) + try: + # make sure array is at least 1d (so that we succeed when only + # passed a scalar) + xi = tuple(np.atleast_1d(np.asarray(x, dtype=float)) for x in xi) + except TypeError: + raise TypeError("Input samples must be array-like of floats!") # check for non-empty samples if self._has_zero_samples(tuple(len(x) for x in xi)): @@ -119,10 +113,7 @@ def evaluate(self, *xi: ArrayLike) -> NDArray: # checks on input and outputs self._check_samples_consistency(*xi) self._check_input_dimensionality(xi) - - eval_basis = self._evaluate(*xi) - - return eval_basis + return xi def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: """Evaluate the basis set on a grid of equi-spaced sample points. @@ -134,7 +125,8 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: Parameters ---------- n_samples[0],...,n_samples[n] - The number of samples in each axis of the grid. + The number of samples in each axis of the grid. The length of + n_samples must equal the number of combined bases. Returns ------- @@ -196,7 +188,7 @@ def _check_input_dimensionality(self, xi: Tuple) -> None: If the number of inputs doesn't match what the Basis object requires. """ if len(xi) != self._n_input_dimensionality: - raise ValueError( + raise TypeError( f"Input dimensionality mismatch. This basis evaluation requires {self._n_input_dimensionality} inputs, " f"{len(xi)} inputs provided instead." ) @@ -335,24 +327,27 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None: def _check_n_basis_min(self) -> None: pass - def _evaluate(self, *xi: NDArray) -> NDArray: + def evaluate(self, *xi: NDArray) -> NDArray: """ Evaluate the basis at the input samples. Parameters ---------- - xi[0], ..., xi[n] : (number of samples, ) - Tuple of input samples. + xi[0], ..., xi[n] : (n_samples,) + Tuple of input samples, each with the same number of samples. The + number of input arrays must equal the number of combined bases. Returns ------- : The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) + """ + xi = self._check_evaluate_input(*xi) return np.hstack( ( - self._basis1._evaluate(*xi[: self._basis1._n_input_dimensionality]), - self._basis2._evaluate(*xi[self._basis1._n_input_dimensionality :]), + self._basis1.evaluate(*xi[: self._basis1._n_input_dimensionality]), + self._basis2.evaluate(*xi[self._basis1._n_input_dimensionality :]), ) ) @@ -388,24 +383,26 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None: def _check_n_basis_min(self) -> None: pass - def _evaluate(self, *xi: NDArray) -> NDArray: + def evaluate(self, *xi: NDArray) -> NDArray: """ Evaluate the basis at the input samples. Parameters ---------- - xi[0], ..., xi[n] : (number of samples, ) - Tuple of input samples. + xi[0], ..., xi[n] : (n_samples,) + Tuple of input samples, each with the same number of samples. The + number of input arrays must equal the number of combined bases. Returns ------- : The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) """ + xi = self._check_evaluate_input(*xi) return np.array( row_wise_kron( - self._basis1._evaluate(*xi[: self._basis1._n_input_dimensionality]), - self._basis2._evaluate(*xi[self._basis1._n_input_dimensionality :]), + self._basis1.evaluate(*xi[: self._basis1._n_input_dimensionality]), + self._basis2.evaluate(*xi[self._basis1._n_input_dimensionality :]), transpose=False, ) ) @@ -448,7 +445,7 @@ def _generate_knots( Parameters ---------- - sample_pts : (number of samples, ) + sample_pts : (n_samples,) The sample points. perc_low The low percentile value, between [0,1). @@ -531,14 +528,13 @@ class MSplineBasis(SplineBasis): def __init__(self, n_basis_funcs: int, order: int = 2) -> None: super().__init__(n_basis_funcs, order) - def _evaluate(self, sample_pts: NDArray) -> NDArray: + def evaluate(self, sample_pts: NDArray) -> NDArray: """Generate basis functions with given spacing. Parameters ---------- sample_pts : - Spacing for basis functions, holding elements on the interval [min(sample_pts), - max(sample_pts)], shape (number of samples, ) + Spacing for basis functions, shape (n_samples,) Returns ------- @@ -546,6 +542,7 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: Evaluated spline basis functions, shape (n_samples, n_basis_funcs). """ + (sample_pts,) = self._check_evaluate_input(sample_pts) # add knots if not passed knot_locs = self._generate_knots( sample_pts, perc_low=0.0, perc_high=1.0, is_cyclic=False @@ -559,6 +556,25 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: axis=1, ) + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """Evaluate the M-spline basis set on a grid of equi-spaced sample points. + + Parameters + ---------- + n_samples : + The number of samples. + + Returns + ------- + X : + Array of shape (n_samples,) containing the equi-spaced sample + points where we've evaluated the basis. + basis_funcs : + Raised cosine basis functions, shape (n_samples, n_basis_funcs) + + """ + return super().evaluate_on_grid(n_samples) + class BSplineBasis(SplineBasis): """ @@ -589,30 +605,31 @@ class BSplineBasis(SplineBasis): def __init__(self, n_basis_funcs: int, order: int = 2): super().__init__(n_basis_funcs, order=order) - def _evaluate(self, sample_pts: NDArray) -> NDArray: + def evaluate(self, sample_pts: NDArray) -> NDArray: """ Evaluate the B-spline basis functions with given sample points. Parameters ---------- sample_pts : - The sample points at which the B-spline is evaluated. + The sample points at which the B-spline is evaluated, shape (n_samples,). Returns ------- - NDArray + basis_funcs : The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) Raises ------ AssertionError - If the sample points are not within the B-spline knots range unless `outer_ok=True`. + If the sample points are not within the B-spline knots. Notes ----- The evaluation is performed by looping over each element and using `splev` from SciPy to compute the basis values. """ + (sample_pts,) = self._check_evaluate_input(sample_pts) # add knots knot_locs = self._generate_knots(sample_pts, 0.0, 1.0) @@ -623,6 +640,29 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: return basis_eval + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """Evaluate the B-spline basis set on a grid of equi-spaced sample points. + + Parameters + ---------- + n_samples : + The number of samples. + + Returns + ------- + X : + Array of shape (n_samples,) containing the equi-spaced sample + points where we've evaluated the basis. + basis_funcs : + Raised cosine basis functions, shape (n_samples, n_basis_funcs) + + Notes + ----- + The evaluation is performed by looping over each element and using `splev` from + SciPy to compute the basis values. + """ + return super().evaluate_on_grid(n_samples) + class CyclicBSplineBasis(SplineBasis): """ @@ -653,30 +693,27 @@ def __init__(self, n_basis_funcs: int, order: int = 2): f"order {self.order} specified instead!" ) - def _evaluate(self, sample_pts: NDArray) -> NDArray: - """ - Evaluate the B-spline basis functions with given sample points. + def evaluate(self, sample_pts: NDArray) -> NDArray: + """Evaluate the Cyclic B-spline basis functions with given sample points. Parameters ---------- sample_pts : - The sample points at which the B-spline is evaluated. Must be a tuple of length 1. + The sample points at which the cyclic B-spline is evaluated, shape + (n_samples,). Returns ------- - NDArray + basis_funcs : The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) - Raises - ------ - AssertionError - If the sample points are not within the B-spline knots range unless `outer_ok=True`. - Notes ----- The evaluation is performed by looping over each element and using `splev` from SciPy to compute the basis values. + """ + (sample_pts,) = self._check_evaluate_input(sample_pts) knot_locs = self._generate_knots(sample_pts, 0.0, 1.0, is_cyclic=True) @@ -710,6 +747,29 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: return basis_eval + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """Evaluate the Cyclic B-spline basis set on a grid of equi-spaced sample points. + + Parameters + ---------- + n_samples : + The number of samples. + + Returns + ------- + X : + Array of shape (n_samples,) containing the equi-spaced sample + points where we've evaluated the basis. + basis_funcs : + Raised cosine basis functions, shape (n_samples, n_basis_funcs) + + Notes + ----- + The evaluation is performed by looping over each element and using `splev` from + SciPy to compute the basis values. + """ + return super().evaluate_on_grid(n_samples) + class RaisedCosineBasis(Basis, abc.ABC): def __init__(self, n_basis_funcs: int) -> None: @@ -724,21 +784,21 @@ def _transform_samples(self, sample_pts: NDArray) -> NDArray: Parameters ---------- sample_pts : - The sample points to be transformed, shape (number of samples, ). + The sample points to be transformed, shape (n_samples,). """ pass - def _evaluate(self, sample_pts: NDArray) -> NDArray: + def evaluate(self, sample_pts: NDArray) -> NDArray: """Generate basis functions with given samples. Parameters ---------- sample_pts : (number of samples,) - Spacing for basis functions, holding elements on interval [0, - 1). A good default is - ``nmo.sample_points.raised_cosine_log`` for log spacing (as used in - [2]_) or ``nmo.sample_points.raised_cosine_linear`` for linear - spacing. + Spacing for basis functions, holding elements on interval [0, 1). A + good default is ``nmo.sample_points.raised_cosine_log`` for log + spacing (as used in [2]_) or + ``nmo.sample_points.raised_cosine_linear`` for linear spacing. + Shape (n_samples,). Returns ------- @@ -749,7 +809,9 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: ------ ValueError If the sample provided do not lie in [0,1]. + """ + (sample_pts,) = self._check_evaluate_input(sample_pts) if any(sample_pts < 0) or any(sample_pts > 1): raise ValueError("Sample points for RaisedCosine basis must lie in [0,1]!") @@ -764,6 +826,25 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: return basis_funcs + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """Evaluate the basis set on a grid of equi-spaced sample points. + + Parameters + ---------- + n_samples : + The number of samples. + + Returns + ------- + X : + Array of shape (n_samples,) containing the equi-spaced sample + points where we've evaluated the basis. + basis_funcs : + Raised cosine basis functions, shape (n_samples, n_basis_funcs) + + """ + return super().evaluate_on_grid(n_samples) + class RaisedCosineBasisLinear(RaisedCosineBasis): """Linearly-spaced raised cosine basis functions used by Pillow et al. [2]_. @@ -794,7 +875,7 @@ def _transform_samples(self, sample_pts: NDArray) -> NDArray: Parameters ---------- sample_pts : - The sample points used for evaluating the splines, shape (number of samples, ) + The sample points used for evaluating the splines, shape (n_samples,) Returns ------- @@ -851,7 +932,7 @@ def _transform_samples(self, sample_pts: NDArray) -> NDArray: Parameters ---------- sample_pts : - The sample points used for evaluating the splines, shape (number of samples, ). + The sample points used for evaluating the splines, shape (n_samples,). Returns ------- @@ -984,20 +1065,23 @@ def _check_sample_size(self, *sample_pts: NDArray): f"but only {sample_pts[0].size} samples provided!" ) - def _evaluate(self, sample_pts: NDArray) -> NDArray: + def evaluate(self, sample_pts: NDArray) -> NDArray: """Generate basis functions with given spacing. Parameters ---------- sample_pts - Spacing for basis functions, holding elements on the interval [0, inf), shape (n_pts,). + Spacing for basis functions, holding elements on the interval [0, + inf), shape (n_samples,). Returns ------- basis_funcs - Evaluated exponentially decaying basis functions, - numerically orthogonalized, shape (number of basis, number of samples). + Evaluated exponentially decaying basis functions, numerically + orthogonalized, shape (n_samples, n_basis_funcs) + """ + (sample_pts,) = self._check_evaluate_input(sample_pts) self._check_sample_range(sample_pts) self._check_sample_size(sample_pts) # because of how scipy.linalg.orth works, have to create a matrix of @@ -1008,6 +1092,26 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: np.stack([np.exp(-lam * sample_pts) for lam in self._decay_rates], axis=1) ) + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """Evaluate the basis set on a grid of equi-spaced sample points. + + Parameters + ---------- + n_samples : + The number of samples. + + Returns + ------- + X : + Array of shape (n_samples,) containing the equi-spaced sample + points where we've evaluated the basis. + basis_funcs : + Evaluated exponentially decaying basis functions, numerically + orthogonalized, shape (n_samples, n_basis_funcs) + + """ + return super().evaluate_on_grid(n_samples) + def mspline(x: NDArray, k: int, i: int, T: NDArray): """Compute M-spline basis function. @@ -1066,7 +1170,8 @@ def bspline( Parameters ---------- sample_pts : - An array containing sample points for which B-spline basis needs to be evaluated. + An array containing sample points for which B-spline basis needs to be evaluated, + shape (n_samples,) knots : An array containing knots for the B-spline basis. The knots are sorted in ascending order. order : diff --git a/tests/test_basis.py b/tests/test_basis.py index bd72cd2a..2d7e83f3 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -7,6 +7,7 @@ import utils_testing import nemos.basis as basis +from contextlib import nullcontext as does_not_raise # automatic define user accessible basis and check the methods @@ -69,21 +70,14 @@ def test_non_empty_samples(self, samples): self.cls(5).evaluate(samples) @pytest.mark.parametrize( - "arraylike", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] + "eval_input", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] ) - def test_input_to_evaluate_is_arraylike(self, arraylike): + def test_evaluate_input(self, eval_input): """ Checks that the sample size of the output from the evaluate() method matches the input sample size. """ basis_obj = self.cls(n_basis_funcs=5) - raise_exception = not isinstance( - arraylike, (tuple, list, np.ndarray, jax.numpy.ndarray) - ) - if raise_exception: - with pytest.raises(TypeError, match="Input samples must be array-like"): - basis_obj.evaluate(arraylike) - else: - basis_obj.evaluate(arraylike) + basis_obj.evaluate(eval_input) @pytest.mark.parametrize( "args, sample_size", @@ -160,16 +154,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): Confirms that the evaluate() method correctly handles the number of input samples that are provided. """ basis_obj = self.cls(n_basis_funcs=5) - raise_exception = n_input != basis_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ " - "inputs,", - ): - basis_obj.evaluate(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @@ -211,15 +203,13 @@ def test_evaluate_on_grid_input_number(self, n_input): """ basis_obj = self.cls(n_basis_funcs=5) inputs = [10] * n_input - raise_exception = n_input != basis_obj._n_input_dimensionality - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs, " - r"[0-9]+ inputs provided instead.", - ): - basis_obj.evaluate_on_grid(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @@ -237,21 +227,14 @@ def test_non_empty_samples(self, samples): self.cls(5).evaluate(samples) @pytest.mark.parametrize( - "arraylike", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] + "eval_input", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] ) - def test_input_to_evaluate_is_arraylike(self, arraylike): + def test_evaluate_input(self, eval_input): """ Checks that the sample size of the output from the evaluate() method matches the input sample size. """ basis_obj = self.cls(n_basis_funcs=5) - raise_exception = not isinstance( - arraylike, (tuple, list, np.ndarray, jax.numpy.ndarray) - ) - if raise_exception: - with pytest.raises(TypeError, match="Input samples must be array-like"): - basis_obj.evaluate(arraylike) - else: - basis_obj.evaluate(arraylike) + basis_obj.evaluate(eval_input) @pytest.mark.parametrize( "args, sample_size", @@ -327,15 +310,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): Confirms that the evaluate() method correctly handles the number of input samples that are provided. """ basis_obj = self.cls(n_basis_funcs=5) - raise_exception = n_input != basis_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,", - ): - basis_obj.evaluate(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @@ -377,15 +359,13 @@ def test_evaluate_on_grid_input_number(self, n_input): """ basis_obj = self.cls(n_basis_funcs=5) inputs = [10] * n_input - raise_exception = n_input != basis_obj._n_input_dimensionality - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs, " - r"[0-9]+ inputs provided instead.", - ): - basis_obj.evaluate_on_grid(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @@ -403,21 +383,14 @@ def test_non_empty_samples(self, samples): self.cls(5).evaluate(samples) @pytest.mark.parametrize( - "arraylike", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] + "eval_input", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] ) - def test_input_to_evaluate_is_arraylike(self, arraylike): + def test_evaluate_input(self, eval_input): """ Checks that the sample size of the output from the evaluate() method matches the input sample size. """ basis_obj = self.cls(n_basis_funcs=5) - raise_exception = not isinstance( - arraylike, (tuple, list, np.ndarray, jax.numpy.ndarray) - ) - if raise_exception: - with pytest.raises(TypeError, match="Input samples must be array-like"): - basis_obj.evaluate(arraylike) - else: - basis_obj.evaluate(arraylike) + basis_obj.evaluate(eval_input) @pytest.mark.parametrize("n_basis_funcs", [6, 8, 10]) @pytest.mark.parametrize("order", range(1, 6)) @@ -491,15 +464,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): Confirms that the evaluate() method correctly handles the number of input samples that are provided. """ basis_obj = self.cls(n_basis_funcs=5, order=3) - raise_exception = n_input != basis_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,", - ): - basis_obj.evaluate(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @@ -541,15 +513,13 @@ def test_evaluate_on_grid_input_number(self, n_input): """ basis_obj = self.cls(n_basis_funcs=5, order=3) inputs = [10] * n_input - raise_exception = n_input != basis_obj._n_input_dimensionality - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs, " - r"[0-9]+ inputs provided instead.", - ): - basis_obj.evaluate_on_grid(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @@ -568,21 +538,19 @@ def test_non_empty_samples(self, samples): self.cls(5, decay_rates=np.arange(1, 6)).evaluate(samples) @pytest.mark.parametrize( - "arraylike", [0, [0]*6, (0,)*6, np.array([0]*6), jax.numpy.array([0]*6)] + "eval_input", [0, [0]*6, (0,)*6, np.array([0]*6), jax.numpy.array([0]*6)] ) - def test_input_to_evaluate_is_arraylike(self, arraylike): + def test_evaluate_input(self, eval_input): """ Checks that the sample size of the output from the evaluate() method matches the input sample size. """ basis_obj = self.cls(n_basis_funcs=5, decay_rates=np.arange(1, 6)) - raise_exception = not isinstance( - arraylike, (tuple, list, np.ndarray, jax.numpy.ndarray) - ) - if raise_exception: - with pytest.raises(TypeError, match="Input samples must be array-like"): - basis_obj.evaluate(arraylike) + if isinstance(eval_input, int): + # OrthExponentialBasis is special -- cannot accept int input + with pytest.raises(ValueError, match="OrthExponentialBasis requires at least as many samples"): + basis_obj.evaluate(eval_input) else: - basis_obj.evaluate(arraylike) + basis_obj.evaluate(eval_input) @pytest.mark.parametrize("n_basis_funcs", [1, 2, 4, 8]) @pytest.mark.parametrize("sample_size", [10, 1000]) @@ -656,15 +624,14 @@ def test_samples_range_matches_evaluate_requirements(self, sample_range: tuple): def test_number_of_required_inputs_evaluate(self, n_input): """Tests whether the evaluate method correctly processes the number of required inputs.""" basis_obj = self.cls(n_basis_funcs=5, decay_rates=np.arange(1, 6)) - raise_exception = n_input != basis_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,", - ): - basis_obj.evaluate(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [-1, 0, 1, 2, 3, 4, 5, 6, 10, 11, 100]) @@ -705,15 +672,13 @@ def test_evaluate_on_grid_input_number(self, n_input): """Tests whether the evaluate_on_grid method correctly processes the Input dimensionality.""" basis_obj = self.cls(n_basis_funcs=5, decay_rates=np.arange(1, 6)) inputs = [10] * n_input - raise_exception = n_input != basis_obj._n_input_dimensionality - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs, " - r"[0-9]+ inputs provided instead.", - ): - basis_obj.evaluate_on_grid(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @pytest.mark.parametrize( @@ -766,21 +731,14 @@ def test_non_empty_samples(self, samples): self.cls(5).evaluate(samples) @pytest.mark.parametrize( - "arraylike", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] + "eval_input", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] ) - def test_input_to_evaluate_is_arraylike(self, arraylike): + def test_evaluate_input(self, eval_input): """ Checks that the sample size of the output from the evaluate() method matches the input sample size. """ basis_obj = self.cls(n_basis_funcs=5) - raise_exception = not isinstance( - arraylike, (tuple, list, np.ndarray, jax.numpy.ndarray) - ) - if raise_exception: - with pytest.raises(TypeError, match="Input samples must be array-like"): - basis_obj.evaluate(arraylike) - else: - basis_obj.evaluate(arraylike) + basis_obj.evaluate(eval_input) @pytest.mark.parametrize("n_basis_funcs", [6, 8, 10]) @pytest.mark.parametrize("order", range(1, 6)) @@ -869,15 +827,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): Confirms that the evaluate() method correctly handles the number of input samples that are provided. """ basis_obj = self.cls(n_basis_funcs=5, order=3) - raise_exception = n_input != basis_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,", - ): - basis_obj.evaluate(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @@ -923,14 +880,13 @@ def test_evaluate_on_grid_input_number(self, n_input): """ basis_obj = self.cls(n_basis_funcs=5, order=3) inputs = [10] * n_input - raise_exception = n_input != basis_obj._n_input_dimensionality - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs", - ): - basis_obj.evaluate_on_grid(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @@ -948,21 +904,14 @@ def test_non_empty_samples(self, samples): self.cls(5).evaluate(samples) @pytest.mark.parametrize( - "arraylike", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] + "eval_input", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] ) - def test_input_to_evaluate_is_arraylike(self, arraylike): + def test_evaluate_input(self, eval_input): """ Checks that the sample size of the output from the evaluate() method matches the input sample size. """ basis_obj = self.cls(n_basis_funcs=5) - raise_exception = not isinstance( - arraylike, (tuple, list, np.ndarray, jax.numpy.ndarray) - ) - if raise_exception: - with pytest.raises(TypeError, match="Input samples must be array-like"): - basis_obj.evaluate(arraylike) - else: - basis_obj.evaluate(arraylike) + basis_obj.evaluate(eval_input) @pytest.mark.parametrize("n_basis_funcs", [8, 10]) @pytest.mark.parametrize("order", range(2, 6)) @@ -1069,15 +1018,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): Confirms that the evaluate() method correctly handles the number of input samples that are provided. """ basis_obj = self.cls(n_basis_funcs=5, order=3) - raise_exception = n_input != basis_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs", - ): - basis_obj.evaluate(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @@ -1123,14 +1071,13 @@ def test_evaluate_on_grid_input_number(self, n_input): """ basis_obj = self.cls(n_basis_funcs=5, order=3) inputs = [10] * n_input - raise_exception = n_input != basis_obj._n_input_dimensionality - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs", - ): - basis_obj.evaluate_on_grid(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @@ -1191,7 +1138,7 @@ def test_non_empty_samples(self, samples): basis_obj.evaluate(*samples) @pytest.mark.parametrize( - "arraylike", + "eval_input", [ [0, 0], [[0], [0]], @@ -1200,20 +1147,12 @@ def test_non_empty_samples(self, samples): [jax.numpy.array([0]), [0]], ], ) - def test_input_to_evaluate_is_arraylike(self, arraylike): + def test_evaluate_input(self, eval_input): """ Checks that the sample size of the output from the evaluate() method matches the input sample size. """ basis_obj = basis.MSplineBasis(5) + basis.MSplineBasis(5) - raise_exception = not all( - isinstance(a, (tuple, list, np.ndarray, jax.numpy.ndarray)) - for a in arraylike - ) - if raise_exception: - with pytest.raises(TypeError, match="Input samples must be array-like"): - basis_obj.evaluate(*arraylike) - else: - basis_obj.evaluate(*arraylike) + basis_obj.evaluate(*eval_input) @pytest.mark.parametrize("n_basis_a", [5, 6]) @pytest.mark.parametrize("n_basis_b", [5, 6]) @@ -1299,18 +1238,13 @@ def test_number_of_required_inputs_evaluate( basis_a_obj = self.instantiate_basis(n_basis_a, basis_a) basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) basis_obj = basis_a_obj + basis_b_obj - raise_exception = ( - n_input - != basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality - ) + required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,", - ): - basis_obj.evaluate(*inputs) + if n_input != required_dim: + expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.") else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [11, 20]) @@ -1386,18 +1320,12 @@ def test_evaluate_on_grid_input_number( basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) basis_obj = basis_a_obj + basis_b_obj inputs = [20] * n_input - raise_exception = ( - n_input - != basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality - ) - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs, " - r"[0-9]+ inputs provided instead.", - ): - basis_obj.evaluate_on_grid(*inputs) + required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality + if n_input != required_dim: + expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.") else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @@ -1418,7 +1346,7 @@ def test_non_empty_samples(self, samples): basis_obj.evaluate(*samples) @pytest.mark.parametrize( - "arraylike", + "eval_input", [ [0, 0], [[0], [0]], @@ -1427,20 +1355,12 @@ def test_non_empty_samples(self, samples): [jax.numpy.array([0]), [0]], ], ) - def test_input_to_evaluate_is_arraylike(self, arraylike): + def test_evaluate_input(self, eval_input): """ Checks that the sample size of the output from the evaluate() method matches the input sample size. """ basis_obj = basis.MSplineBasis(5) * basis.MSplineBasis(5) - raise_exception = not all( - isinstance(a, (tuple, list, np.ndarray, jax.numpy.ndarray)) - for a in arraylike - ) - if raise_exception: - with pytest.raises(TypeError, match="Input samples must be array-like"): - basis_obj.evaluate(*arraylike) - else: - basis_obj.evaluate(*arraylike) + basis_obj.evaluate(*eval_input) @pytest.mark.parametrize("n_basis_a", [5, 6]) @pytest.mark.parametrize("n_basis_b", [5, 6]) @@ -1526,18 +1446,13 @@ def test_number_of_required_inputs_evaluate( basis_a_obj = self.instantiate_basis(n_basis_a, basis_a) basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) basis_obj = basis_a_obj * basis_b_obj - raise_exception = ( - n_input - != basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality - ) + required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,", - ): - basis_obj.evaluate(*inputs) + if n_input != required_dim: + expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.") else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [11, 20]) @@ -1613,18 +1528,12 @@ def test_evaluate_on_grid_input_number( basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) basis_obj = basis_a_obj * basis_b_obj inputs = [20] * n_input - raise_exception = ( - n_input - != basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality - ) - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs, " - r"[0-9]+ inputs provided instead.", - ): - basis_obj.evaluate_on_grid(*inputs) + required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality + if n_input != required_dim: + expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.") else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @pytest.mark.parametrize("basis_a", [basis.MSplineBasis])