From 6ae6c8e6b8639530202cb5dc4a83f3c7ad6e570f Mon Sep 17 00:00:00 2001 From: "William F. Broderick" Date: Mon, 20 Nov 2023 17:10:51 -0500 Subject: [PATCH 1/7] swaps inheritance around changes where the public method evaluate is defined, so that the user-facing docstring for the simple basis objects (i.e., not Additive or Multiplicative) is clearer --- src/nemos/basis.py | 238 ++++++++++++++++++++++++++++++++------------- 1 file changed, 171 insertions(+), 67 deletions(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index ef36a28b..44249cd3 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -52,19 +52,19 @@ def __init__(self, n_basis_funcs: int) -> None: self._check_n_basis_min() @abc.abstractmethod - def _evaluate(self, *xi: NDArray) -> NDArray: + def evaluate(self, *xi: NDArray) -> NDArray: """ - Evaluate the basis set at the given samples x1,...,xn using the subclass-specific "_evaluate" method. + Evaluate the basis set at the given samples x1,...,xn using the subclass-specific "evaluate" method. Parameters ---------- - *xi: (number of samples, ) + *xi: (n_samples,) The input samples xi[0],...,xi[n] . """ pass @staticmethod - def _get_samples(*n_samples: int) -> Generator[NDArray, ...]: + def _get_samples(*n_samples: int): """Get equi-spaced samples for all the input dimensions. This will be used to evaluate the basis on a grid of @@ -82,35 +82,30 @@ def _get_samples(*n_samples: int) -> Generator[NDArray, ...]: """ return (np.linspace(0, 1, n_samples[k]) for k in range(len(n_samples))) - def evaluate(self, *xi: ArrayLike) -> NDArray: - """ - Evaluate the basis set at the given samples x[0],...,x[n] using the subclass-specific "_evaluate" method. + def _check_evaluate_input(self, *xi: ArrayLike) -> Tuple[NDArray]: + """Check evaluate input. Parameters ---------- xi[0],...,xi[n] : The input samples, each with shape (number of samples, ). - Returns - ------- - : - The generated basis functions. - Raises ------ ValueError - If the time point number is inconsistent between inputs. - If the number of inputs doesn't match what the Basis object requires. - At least one of the samples is empty. - """ - # check that the input is array-like - if any( - not isinstance(x, (list, tuple, np.ndarray, jax.numpy.ndarray)) for x in xi - ): - raise TypeError("Input samples must be array-like!") - # convert to numpy.array of floats - xi = tuple(np.asarray(x, dtype=float) for x in xi) + """ + # check that the input is array-like (i.e., whether we can cast it to + # numeric arrays) + try: + # make sure array is at least 1d (so that we succeed when only + # passed a scalar) + xi = tuple(np.atleast_1d(np.asarray(x, dtype=float)) for x in xi) + except TypeError: + raise TypeError("Input samples must be array-like of floats!") # check for non-empty samples if self._has_zero_samples(tuple(len(x) for x in xi)): @@ -119,10 +114,7 @@ def evaluate(self, *xi: ArrayLike) -> NDArray: # checks on input and outputs self._check_samples_consistency(*xi) self._check_input_dimensionality(xi) - - eval_basis = self._evaluate(*xi) - - return eval_basis + return xi def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: """Evaluate the basis set on a grid of equi-spaced sample points. @@ -134,7 +126,8 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: Parameters ---------- n_samples[0],...,n_samples[n] - The number of samples in each axis of the grid. + The number of samples in each axis of the grid. The length of + n_samples must equal the number of combined bases. Returns ------- @@ -335,24 +328,27 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None: def _check_n_basis_min(self) -> None: pass - def _evaluate(self, *xi: NDArray) -> NDArray: + def evaluate(self, *xi: NDArray) -> NDArray: """ Evaluate the basis at the input samples. Parameters ---------- - xi[0], ..., xi[n] : (number of samples, ) - Tuple of input samples. + xi[0], ..., xi[n] : (n_samples,) + Tuple of input samples, each with the same number of samples. The + number of input arrays must equal the number of combined bases. Returns ------- : The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) + """ + xi = self._check_evaluate_input(*xi) return np.hstack( ( - self._basis1._evaluate(*xi[: self._basis1._n_input_dimensionality]), - self._basis2._evaluate(*xi[self._basis1._n_input_dimensionality :]), + self._basis1.evaluate(*xi[: self._basis1._n_input_dimensionality]), + self._basis2.evaluate(*xi[self._basis1._n_input_dimensionality :]), ) ) @@ -388,24 +384,26 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None: def _check_n_basis_min(self) -> None: pass - def _evaluate(self, *xi: NDArray) -> NDArray: + def evaluate(self, *xi: NDArray) -> NDArray: """ Evaluate the basis at the input samples. Parameters ---------- - xi[0], ..., xi[n] : (number of samples, ) - Tuple of input samples. + xi[0], ..., xi[n] : (n_samples,) + Tuple of input samples, each with the same number of samples. The + number of input arrays must equal the number of combined bases. Returns ------- : The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) """ + xi = self._check_evaluate_input(*xi) return np.array( row_wise_kron( - self._basis1._evaluate(*xi[: self._basis1._n_input_dimensionality]), - self._basis2._evaluate(*xi[self._basis1._n_input_dimensionality :]), + self._basis1.evaluate(*xi[: self._basis1._n_input_dimensionality]), + self._basis2.evaluate(*xi[self._basis1._n_input_dimensionality :]), transpose=False, ) ) @@ -448,7 +446,7 @@ def _generate_knots( Parameters ---------- - sample_pts : (number of samples, ) + sample_pts : (n_samples,) The sample points. perc_low The low percentile value, between [0,1). @@ -531,14 +529,13 @@ class MSplineBasis(SplineBasis): def __init__(self, n_basis_funcs: int, order: int = 2) -> None: super().__init__(n_basis_funcs, order) - def _evaluate(self, sample_pts: NDArray) -> NDArray: + def evaluate(self, sample_pts: NDArray) -> NDArray: """Generate basis functions with given spacing. Parameters ---------- sample_pts : - Spacing for basis functions, holding elements on the interval [min(sample_pts), - max(sample_pts)], shape (number of samples, ) + Spacing for basis functions, shape (n_samples,) Returns ------- @@ -546,6 +543,7 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: Evaluated spline basis functions, shape (n_samples, n_basis_funcs). """ + sample_pts, = self._check_evaluate_input(sample_pts) # add knots if not passed knot_locs = self._generate_knots( sample_pts, perc_low=0.0, perc_high=1.0, is_cyclic=False @@ -559,6 +557,24 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: axis=1, ) + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """Evaluate the M-spline basis set on a grid of equi-spaced sample points. + + Parameters + ---------- + n_samples : + The number of samples. + + Returns + ------- + X : + Array of shape (n_samples,) containing the equi-spaced sample + points where we've evaluated the basis. + basis_funcs : + Raised cosine basis functions, shape (n_samples, n_basis_funcs) + + """ + return super().evaluate_on_grid(n_samples) class BSplineBasis(SplineBasis): """ @@ -589,30 +605,31 @@ class BSplineBasis(SplineBasis): def __init__(self, n_basis_funcs: int, order: int = 2): super().__init__(n_basis_funcs, order=order) - def _evaluate(self, sample_pts: NDArray) -> NDArray: + def evaluate(self, sample_pts: NDArray) -> NDArray: """ Evaluate the B-spline basis functions with given sample points. Parameters ---------- sample_pts : - The sample points at which the B-spline is evaluated. + The sample points at which the B-spline is evaluated, shape (n_samples,). Returns ------- - NDArray + basis_funcs : The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) Raises ------ AssertionError - If the sample points are not within the B-spline knots range unless `outer_ok=True`. + If the sample points are not within the B-spline knots. Notes ----- The evaluation is performed by looping over each element and using `splev` from SciPy to compute the basis values. """ + sample_pts, = self._check_evaluate_input(sample_pts) # add knots knot_locs = self._generate_knots(sample_pts, 0.0, 1.0) @@ -623,6 +640,29 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: return basis_eval + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """Evaluate the B-spline basis set on a grid of equi-spaced sample points. + + Parameters + ---------- + n_samples : + The number of samples. + + Returns + ------- + X : + Array of shape (n_samples,) containing the equi-spaced sample + points where we've evaluated the basis. + basis_funcs : + Raised cosine basis functions, shape (n_samples, n_basis_funcs) + + Notes + ----- + The evaluation is performed by looping over each element and using `splev` from + SciPy to compute the basis values. + """ + return super().evaluate_on_grid(n_samples) + class CyclicBSplineBasis(SplineBasis): """ @@ -653,30 +693,27 @@ def __init__(self, n_basis_funcs: int, order: int = 2): f"order {self.order} specified instead!" ) - def _evaluate(self, sample_pts: NDArray) -> NDArray: - """ - Evaluate the B-spline basis functions with given sample points. + def evaluate(self, sample_pts: NDArray) -> NDArray: + """Evaluate the Cyclic B-spline basis functions with given sample points. Parameters ---------- sample_pts : - The sample points at which the B-spline is evaluated. Must be a tuple of length 1. + The sample points at which the cyclic B-spline is evaluated, shape + (n_samples,). Returns ------- - NDArray + basis_funcs : The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) - Raises - ------ - AssertionError - If the sample points are not within the B-spline knots range unless `outer_ok=True`. - Notes ----- The evaluation is performed by looping over each element and using `splev` from SciPy to compute the basis values. + """ + sample_pts, = self._check_evaluate_input(sample_pts) knot_locs = self._generate_knots(sample_pts, 0.0, 1.0, is_cyclic=True) @@ -710,6 +747,28 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: return basis_eval + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """Evaluate the Cyclic B-spline basis set on a grid of equi-spaced sample points. + + Parameters + ---------- + n_samples : + The number of samples. + + Returns + ------- + X : + Array of shape (n_samples,) containing the equi-spaced sample + points where we've evaluated the basis. + basis_funcs : + Raised cosine basis functions, shape (n_samples, n_basis_funcs) + + Notes + ----- + The evaluation is performed by looping over each element and using `splev` from + SciPy to compute the basis values. + """ + return super().evaluate_on_grid(n_samples) class RaisedCosineBasis(Basis, abc.ABC): def __init__(self, n_basis_funcs: int) -> None: @@ -724,21 +783,21 @@ def _transform_samples(self, sample_pts: NDArray) -> NDArray: Parameters ---------- sample_pts : - The sample points to be transformed, shape (number of samples, ). + The sample points to be transformed, shape (n_samples,). """ pass - def _evaluate(self, sample_pts: NDArray) -> NDArray: + def evaluate(self, sample_pts: NDArray) -> NDArray: """Generate basis functions with given samples. Parameters ---------- sample_pts : (number of samples,) - Spacing for basis functions, holding elements on interval [0, - 1). A good default is - ``nmo.sample_points.raised_cosine_log`` for log spacing (as used in - [2]_) or ``nmo.sample_points.raised_cosine_linear`` for linear - spacing. + Spacing for basis functions, holding elements on interval [0, 1). A + good default is ``nmo.sample_points.raised_cosine_log`` for log + spacing (as used in [2]_) or + ``nmo.sample_points.raised_cosine_linear`` for linear spacing. + Shape (n_samples,). Returns ------- @@ -749,7 +808,9 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: ------ ValueError If the sample provided do not lie in [0,1]. + """ + sample_pts, = self._check_evaluate_input(sample_pts) if any(sample_pts < 0) or any(sample_pts > 1): raise ValueError("Sample points for RaisedCosine basis must lie in [0,1]!") @@ -764,6 +825,25 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: return basis_funcs + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """Evaluate the basis set on a grid of equi-spaced sample points. + + Parameters + ---------- + n_samples : + The number of samples. + + Returns + ------- + X : + Array of shape (n_samples,) containing the equi-spaced sample + points where we've evaluated the basis. + basis_funcs : + Raised cosine basis functions, shape (n_samples, n_basis_funcs) + + """ + return super().evaluate_on_grid(n_samples) + class RaisedCosineBasisLinear(RaisedCosineBasis): """Linearly-spaced raised cosine basis functions used by Pillow et al. [2]_. @@ -794,7 +874,7 @@ def _transform_samples(self, sample_pts: NDArray) -> NDArray: Parameters ---------- sample_pts : - The sample points used for evaluating the splines, shape (number of samples, ) + The sample points used for evaluating the splines, shape (n_samples,) Returns ------- @@ -851,7 +931,7 @@ def _transform_samples(self, sample_pts: NDArray) -> NDArray: Parameters ---------- sample_pts : - The sample points used for evaluating the splines, shape (number of samples, ). + The sample points used for evaluating the splines, shape (n_samples,). Returns ------- @@ -984,20 +1064,23 @@ def _check_sample_size(self, *sample_pts: NDArray): f"but only {sample_pts[0].size} samples provided!" ) - def _evaluate(self, sample_pts: NDArray) -> NDArray: + def evaluate(self, sample_pts: NDArray) -> NDArray: """Generate basis functions with given spacing. Parameters ---------- sample_pts - Spacing for basis functions, holding elements on the interval [0, inf), shape (n_pts,). + Spacing for basis functions, holding elements on the interval [0, + inf), shape (n_samples,). Returns ------- basis_funcs - Evaluated exponentially decaying basis functions, - numerically orthogonalized, shape (number of basis, number of samples). + Evaluated exponentially decaying basis functions, numerically + orthogonalized, shape (n_samples, n_basis_funcs) + """ + sample_pts, = self._check_evaluate_input(sample_pts) self._check_sample_range(sample_pts) self._check_sample_size(sample_pts) # because of how scipy.linalg.orth works, have to create a matrix of @@ -1008,6 +1091,26 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: np.stack([np.exp(-lam * sample_pts) for lam in self._decay_rates], axis=1) ) + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """Evaluate the basis set on a grid of equi-spaced sample points. + + Parameters + ---------- + n_samples : + The number of samples. + + Returns + ------- + X : + Array of shape (n_samples,) containing the equi-spaced sample + points where we've evaluated the basis. + basis_funcs : + Evaluated exponentially decaying basis functions, numerically + orthogonalized, shape (n_samples, n_basis_funcs) + + """ + return super().evaluate_on_grid(n_samples) + def mspline(x: NDArray, k: int, i: int, T: NDArray): """Compute M-spline basis function. @@ -1066,7 +1169,8 @@ def bspline( Parameters ---------- sample_pts : - An array containing sample points for which B-spline basis needs to be evaluated. + An array containing sample points for which B-spline basis needs to be evaluated, + shape (n_samples,) knots : An array containing knots for the B-spline basis. The knots are sorted in ascending order. order : From 5cfdffc60f4d521321df157d44dc85c85e749498 Mon Sep 17 00:00:00 2001 From: "William F. Broderick" Date: Mon, 20 Nov 2023 17:20:28 -0500 Subject: [PATCH 2/7] didn't mean to remove that... --- src/nemos/basis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 44249cd3..0481dc92 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -64,7 +64,7 @@ def evaluate(self, *xi: NDArray) -> NDArray: pass @staticmethod - def _get_samples(*n_samples: int): + def _get_samples(*n_samples: int) -> Generator[NDArray, ...]: """Get equi-spaced samples for all the input dimensions. This will be used to evaluate the basis on a grid of From 4abe959c0e003fe2963feb9f69a89cd5c948fa84 Mon Sep 17 00:00:00 2001 From: "William F. Broderick" Date: Mon, 20 Nov 2023 17:23:19 -0500 Subject: [PATCH 3/7] runs flake8 on code --- src/nemos/basis.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 0481dc92..efe8fa70 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -6,7 +6,6 @@ import abc from typing import Generator, Tuple -import jax.numpy import numpy as np import scipy.linalg from numpy.typing import ArrayLike, NDArray From 3400d0968f54b5f22fcc5b6afd1b12f7025fb924 Mon Sep 17 00:00:00 2001 From: "William F. Broderick" Date: Mon, 20 Nov 2023 17:23:28 -0500 Subject: [PATCH 4/7] runs black on code --- src/nemos/basis.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index efe8fa70..4ccff352 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -542,7 +542,7 @@ def evaluate(self, sample_pts: NDArray) -> NDArray: Evaluated spline basis functions, shape (n_samples, n_basis_funcs). """ - sample_pts, = self._check_evaluate_input(sample_pts) + (sample_pts,) = self._check_evaluate_input(sample_pts) # add knots if not passed knot_locs = self._generate_knots( sample_pts, perc_low=0.0, perc_high=1.0, is_cyclic=False @@ -575,6 +575,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples) + class BSplineBasis(SplineBasis): """ B-spline 1-dimensional basis functions. @@ -628,7 +629,7 @@ def evaluate(self, sample_pts: NDArray) -> NDArray: The evaluation is performed by looping over each element and using `splev` from SciPy to compute the basis values. """ - sample_pts, = self._check_evaluate_input(sample_pts) + (sample_pts,) = self._check_evaluate_input(sample_pts) # add knots knot_locs = self._generate_knots(sample_pts, 0.0, 1.0) @@ -712,7 +713,7 @@ def evaluate(self, sample_pts: NDArray) -> NDArray: SciPy to compute the basis values. """ - sample_pts, = self._check_evaluate_input(sample_pts) + (sample_pts,) = self._check_evaluate_input(sample_pts) knot_locs = self._generate_knots(sample_pts, 0.0, 1.0, is_cyclic=True) @@ -769,6 +770,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples) + class RaisedCosineBasis(Basis, abc.ABC): def __init__(self, n_basis_funcs: int) -> None: super().__init__(n_basis_funcs) @@ -809,7 +811,7 @@ def evaluate(self, sample_pts: NDArray) -> NDArray: If the sample provided do not lie in [0,1]. """ - sample_pts, = self._check_evaluate_input(sample_pts) + (sample_pts,) = self._check_evaluate_input(sample_pts) if any(sample_pts < 0) or any(sample_pts > 1): raise ValueError("Sample points for RaisedCosine basis must lie in [0,1]!") @@ -1079,7 +1081,7 @@ def evaluate(self, sample_pts: NDArray) -> NDArray: orthogonalized, shape (n_samples, n_basis_funcs) """ - sample_pts, = self._check_evaluate_input(sample_pts) + (sample_pts,) = self._check_evaluate_input(sample_pts) self._check_sample_range(sample_pts) self._check_sample_size(sample_pts) # because of how scipy.linalg.orth works, have to create a matrix of From 664bc4c8d1110731c58bb0d8f689b629ddca5295 Mon Sep 17 00:00:00 2001 From: "William F. Broderick" Date: Mon, 11 Dec 2023 15:54:51 -0500 Subject: [PATCH 5/7] test input arraylike -> test_evaluate_input basis.evaluate now supports single int inputs, update tests to reflect that --- tests/test_basis.py | 109 ++++++++++++-------------------------------- 1 file changed, 28 insertions(+), 81 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index bd72cd2a..2e66a6cb 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -69,21 +69,14 @@ def test_non_empty_samples(self, samples): self.cls(5).evaluate(samples) @pytest.mark.parametrize( - "arraylike", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] + "eval_input", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] ) - def test_input_to_evaluate_is_arraylike(self, arraylike): + def test_evaluate_input(self, eval_input): """ Checks that the sample size of the output from the evaluate() method matches the input sample size. """ basis_obj = self.cls(n_basis_funcs=5) - raise_exception = not isinstance( - arraylike, (tuple, list, np.ndarray, jax.numpy.ndarray) - ) - if raise_exception: - with pytest.raises(TypeError, match="Input samples must be array-like"): - basis_obj.evaluate(arraylike) - else: - basis_obj.evaluate(arraylike) + basis_obj.evaluate(eval_input) @pytest.mark.parametrize( "args, sample_size", @@ -237,21 +230,14 @@ def test_non_empty_samples(self, samples): self.cls(5).evaluate(samples) @pytest.mark.parametrize( - "arraylike", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] + "eval_input", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] ) - def test_input_to_evaluate_is_arraylike(self, arraylike): + def test_evaluate_input(self, eval_input): """ Checks that the sample size of the output from the evaluate() method matches the input sample size. """ basis_obj = self.cls(n_basis_funcs=5) - raise_exception = not isinstance( - arraylike, (tuple, list, np.ndarray, jax.numpy.ndarray) - ) - if raise_exception: - with pytest.raises(TypeError, match="Input samples must be array-like"): - basis_obj.evaluate(arraylike) - else: - basis_obj.evaluate(arraylike) + basis_obj.evaluate(eval_input) @pytest.mark.parametrize( "args, sample_size", @@ -403,21 +389,14 @@ def test_non_empty_samples(self, samples): self.cls(5).evaluate(samples) @pytest.mark.parametrize( - "arraylike", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] + "eval_input", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] ) - def test_input_to_evaluate_is_arraylike(self, arraylike): + def test_evaluate_input(self, eval_input): """ Checks that the sample size of the output from the evaluate() method matches the input sample size. """ basis_obj = self.cls(n_basis_funcs=5) - raise_exception = not isinstance( - arraylike, (tuple, list, np.ndarray, jax.numpy.ndarray) - ) - if raise_exception: - with pytest.raises(TypeError, match="Input samples must be array-like"): - basis_obj.evaluate(arraylike) - else: - basis_obj.evaluate(arraylike) + basis_obj.evaluate(eval_input) @pytest.mark.parametrize("n_basis_funcs", [6, 8, 10]) @pytest.mark.parametrize("order", range(1, 6)) @@ -568,21 +547,19 @@ def test_non_empty_samples(self, samples): self.cls(5, decay_rates=np.arange(1, 6)).evaluate(samples) @pytest.mark.parametrize( - "arraylike", [0, [0]*6, (0,)*6, np.array([0]*6), jax.numpy.array([0]*6)] + "eval_input", [0, [0]*6, (0,)*6, np.array([0]*6), jax.numpy.array([0]*6)] ) - def test_input_to_evaluate_is_arraylike(self, arraylike): + def test_evaluate_input(self, eval_input): """ Checks that the sample size of the output from the evaluate() method matches the input sample size. """ basis_obj = self.cls(n_basis_funcs=5, decay_rates=np.arange(1, 6)) - raise_exception = not isinstance( - arraylike, (tuple, list, np.ndarray, jax.numpy.ndarray) - ) - if raise_exception: - with pytest.raises(TypeError, match="Input samples must be array-like"): - basis_obj.evaluate(arraylike) + if isinstance(eval_input, int): + # OrthExponentialBasis is special -- cannot accept int input + with pytest.raises(ValueError, match="OrthExponentialBasis requires at least as many samples"): + basis_obj.evaluate(eval_input) else: - basis_obj.evaluate(arraylike) + basis_obj.evaluate(eval_input) @pytest.mark.parametrize("n_basis_funcs", [1, 2, 4, 8]) @pytest.mark.parametrize("sample_size", [10, 1000]) @@ -766,21 +743,14 @@ def test_non_empty_samples(self, samples): self.cls(5).evaluate(samples) @pytest.mark.parametrize( - "arraylike", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] + "eval_input", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] ) - def test_input_to_evaluate_is_arraylike(self, arraylike): + def test_evaluate_input(self, eval_input): """ Checks that the sample size of the output from the evaluate() method matches the input sample size. """ basis_obj = self.cls(n_basis_funcs=5) - raise_exception = not isinstance( - arraylike, (tuple, list, np.ndarray, jax.numpy.ndarray) - ) - if raise_exception: - with pytest.raises(TypeError, match="Input samples must be array-like"): - basis_obj.evaluate(arraylike) - else: - basis_obj.evaluate(arraylike) + basis_obj.evaluate(eval_input) @pytest.mark.parametrize("n_basis_funcs", [6, 8, 10]) @pytest.mark.parametrize("order", range(1, 6)) @@ -948,21 +918,14 @@ def test_non_empty_samples(self, samples): self.cls(5).evaluate(samples) @pytest.mark.parametrize( - "arraylike", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] + "eval_input", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] ) - def test_input_to_evaluate_is_arraylike(self, arraylike): + def test_evaluate_input(self, eval_input): """ Checks that the sample size of the output from the evaluate() method matches the input sample size. """ basis_obj = self.cls(n_basis_funcs=5) - raise_exception = not isinstance( - arraylike, (tuple, list, np.ndarray, jax.numpy.ndarray) - ) - if raise_exception: - with pytest.raises(TypeError, match="Input samples must be array-like"): - basis_obj.evaluate(arraylike) - else: - basis_obj.evaluate(arraylike) + basis_obj.evaluate(eval_input) @pytest.mark.parametrize("n_basis_funcs", [8, 10]) @pytest.mark.parametrize("order", range(2, 6)) @@ -1191,7 +1154,7 @@ def test_non_empty_samples(self, samples): basis_obj.evaluate(*samples) @pytest.mark.parametrize( - "arraylike", + "eval_input", [ [0, 0], [[0], [0]], @@ -1200,20 +1163,12 @@ def test_non_empty_samples(self, samples): [jax.numpy.array([0]), [0]], ], ) - def test_input_to_evaluate_is_arraylike(self, arraylike): + def test_evaluate_input(self, eval_input): """ Checks that the sample size of the output from the evaluate() method matches the input sample size. """ basis_obj = basis.MSplineBasis(5) + basis.MSplineBasis(5) - raise_exception = not all( - isinstance(a, (tuple, list, np.ndarray, jax.numpy.ndarray)) - for a in arraylike - ) - if raise_exception: - with pytest.raises(TypeError, match="Input samples must be array-like"): - basis_obj.evaluate(*arraylike) - else: - basis_obj.evaluate(*arraylike) + basis_obj.evaluate(*eval_input) @pytest.mark.parametrize("n_basis_a", [5, 6]) @pytest.mark.parametrize("n_basis_b", [5, 6]) @@ -1418,7 +1373,7 @@ def test_non_empty_samples(self, samples): basis_obj.evaluate(*samples) @pytest.mark.parametrize( - "arraylike", + "eval_input", [ [0, 0], [[0], [0]], @@ -1427,20 +1382,12 @@ def test_non_empty_samples(self, samples): [jax.numpy.array([0]), [0]], ], ) - def test_input_to_evaluate_is_arraylike(self, arraylike): + def test_evaluate_input(self, eval_input): """ Checks that the sample size of the output from the evaluate() method matches the input sample size. """ basis_obj = basis.MSplineBasis(5) * basis.MSplineBasis(5) - raise_exception = not all( - isinstance(a, (tuple, list, np.ndarray, jax.numpy.ndarray)) - for a in arraylike - ) - if raise_exception: - with pytest.raises(TypeError, match="Input samples must be array-like"): - basis_obj.evaluate(*arraylike) - else: - basis_obj.evaluate(*arraylike) + basis_obj.evaluate(*eval_input) @pytest.mark.parametrize("n_basis_a", [5, 6]) @pytest.mark.parametrize("n_basis_b", [5, 6]) From 231c93140a805c31a333a70489301831f40a6a7f Mon Sep 17 00:00:00 2001 From: "William F. Broderick" Date: Mon, 11 Dec 2023 16:18:59 -0500 Subject: [PATCH 6/7] updates Error type for _check_input_dimensionality this way the error raised here (which can only be hit by AdditiveBasis or MultiplicativeBasis) matches the error raised by the single bases objects: MSplineBasis(5).evaluate(5, 5) raises typeError (MSplineBasis(5)+MSplineBasis(5)).evaluate(5) also raises typeError --- src/nemos/basis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 4ccff352..a837f71b 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -188,7 +188,7 @@ def _check_input_dimensionality(self, xi: Tuple) -> None: If the number of inputs doesn't match what the Basis object requires. """ if len(xi) != self._n_input_dimensionality: - raise ValueError( + raise TypeError( f"Input dimensionality mismatch. This basis evaluation requires {self._n_input_dimensionality} inputs, " f"{len(xi)} inputs provided instead." ) From eca9fb9a8a5799ef907f172c5ec47e6452a774c7 Mon Sep 17 00:00:00 2001 From: "William F. Broderick" Date: Mon, 11 Dec 2023 16:20:47 -0500 Subject: [PATCH 7/7] updates tests so they pass --- tests/test_basis.py | 224 ++++++++++++++++++-------------------------- 1 file changed, 93 insertions(+), 131 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index 2e66a6cb..2d7e83f3 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -7,6 +7,7 @@ import utils_testing import nemos.basis as basis +from contextlib import nullcontext as does_not_raise # automatic define user accessible basis and check the methods @@ -153,16 +154,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): Confirms that the evaluate() method correctly handles the number of input samples that are provided. """ basis_obj = self.cls(n_basis_funcs=5) - raise_exception = n_input != basis_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ " - "inputs,", - ): - basis_obj.evaluate(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @@ -204,15 +203,13 @@ def test_evaluate_on_grid_input_number(self, n_input): """ basis_obj = self.cls(n_basis_funcs=5) inputs = [10] * n_input - raise_exception = n_input != basis_obj._n_input_dimensionality - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs, " - r"[0-9]+ inputs provided instead.", - ): - basis_obj.evaluate_on_grid(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @@ -313,15 +310,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): Confirms that the evaluate() method correctly handles the number of input samples that are provided. """ basis_obj = self.cls(n_basis_funcs=5) - raise_exception = n_input != basis_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,", - ): - basis_obj.evaluate(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @@ -363,15 +359,13 @@ def test_evaluate_on_grid_input_number(self, n_input): """ basis_obj = self.cls(n_basis_funcs=5) inputs = [10] * n_input - raise_exception = n_input != basis_obj._n_input_dimensionality - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs, " - r"[0-9]+ inputs provided instead.", - ): - basis_obj.evaluate_on_grid(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @@ -470,15 +464,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): Confirms that the evaluate() method correctly handles the number of input samples that are provided. """ basis_obj = self.cls(n_basis_funcs=5, order=3) - raise_exception = n_input != basis_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,", - ): - basis_obj.evaluate(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @@ -520,15 +513,13 @@ def test_evaluate_on_grid_input_number(self, n_input): """ basis_obj = self.cls(n_basis_funcs=5, order=3) inputs = [10] * n_input - raise_exception = n_input != basis_obj._n_input_dimensionality - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs, " - r"[0-9]+ inputs provided instead.", - ): - basis_obj.evaluate_on_grid(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @@ -633,15 +624,14 @@ def test_samples_range_matches_evaluate_requirements(self, sample_range: tuple): def test_number_of_required_inputs_evaluate(self, n_input): """Tests whether the evaluate method correctly processes the number of required inputs.""" basis_obj = self.cls(n_basis_funcs=5, decay_rates=np.arange(1, 6)) - raise_exception = n_input != basis_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,", - ): - basis_obj.evaluate(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [-1, 0, 1, 2, 3, 4, 5, 6, 10, 11, 100]) @@ -682,15 +672,13 @@ def test_evaluate_on_grid_input_number(self, n_input): """Tests whether the evaluate_on_grid method correctly processes the Input dimensionality.""" basis_obj = self.cls(n_basis_funcs=5, decay_rates=np.arange(1, 6)) inputs = [10] * n_input - raise_exception = n_input != basis_obj._n_input_dimensionality - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs, " - r"[0-9]+ inputs provided instead.", - ): - basis_obj.evaluate_on_grid(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @pytest.mark.parametrize( @@ -839,15 +827,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): Confirms that the evaluate() method correctly handles the number of input samples that are provided. """ basis_obj = self.cls(n_basis_funcs=5, order=3) - raise_exception = n_input != basis_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,", - ): - basis_obj.evaluate(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @@ -893,14 +880,13 @@ def test_evaluate_on_grid_input_number(self, n_input): """ basis_obj = self.cls(n_basis_funcs=5, order=3) inputs = [10] * n_input - raise_exception = n_input != basis_obj._n_input_dimensionality - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs", - ): - basis_obj.evaluate_on_grid(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @@ -1032,15 +1018,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): Confirms that the evaluate() method correctly handles the number of input samples that are provided. """ basis_obj = self.cls(n_basis_funcs=5, order=3) - raise_exception = n_input != basis_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs", - ): - basis_obj.evaluate(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @@ -1086,14 +1071,13 @@ def test_evaluate_on_grid_input_number(self, n_input): """ basis_obj = self.cls(n_basis_funcs=5, order=3) inputs = [10] * n_input - raise_exception = n_input != basis_obj._n_input_dimensionality - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs", - ): - basis_obj.evaluate_on_grid(*inputs) + if n_input == 0: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @@ -1254,18 +1238,13 @@ def test_number_of_required_inputs_evaluate( basis_a_obj = self.instantiate_basis(n_basis_a, basis_a) basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) basis_obj = basis_a_obj + basis_b_obj - raise_exception = ( - n_input - != basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality - ) + required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,", - ): - basis_obj.evaluate(*inputs) + if n_input != required_dim: + expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.") else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [11, 20]) @@ -1341,18 +1320,12 @@ def test_evaluate_on_grid_input_number( basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) basis_obj = basis_a_obj + basis_b_obj inputs = [20] * n_input - raise_exception = ( - n_input - != basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality - ) - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch\. This basis evaluation requires [0-9]+ inputs, " - r"[0-9]+ inputs provided instead.", - ): - basis_obj.evaluate_on_grid(*inputs) + required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality + if n_input != required_dim: + expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.") else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @@ -1473,18 +1446,13 @@ def test_number_of_required_inputs_evaluate( basis_a_obj = self.instantiate_basis(n_basis_a, basis_a) basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) basis_obj = basis_a_obj * basis_b_obj - raise_exception = ( - n_input - != basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality - ) + required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality inputs = [np.linspace(0, 1, 20)] * n_input - if raise_exception: - with pytest.raises( - ValueError, - match="Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs,", - ): - basis_obj.evaluate(*inputs) + if n_input != required_dim: + expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.") else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate(*inputs) @pytest.mark.parametrize("sample_size", [11, 20]) @@ -1560,18 +1528,12 @@ def test_evaluate_on_grid_input_number( basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) basis_obj = basis_a_obj * basis_b_obj inputs = [20] * n_input - raise_exception = ( - n_input - != basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality - ) - if raise_exception: - with pytest.raises( - ValueError, - match=r"Input dimensionality mismatch. This basis evaluation requires [0-9]+ inputs, " - r"[0-9]+ inputs provided instead.", - ): - basis_obj.evaluate_on_grid(*inputs) + required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality + if n_input != required_dim: + expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.") else: + expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) @pytest.mark.parametrize("basis_a", [basis.MSplineBasis])