From 6ae6c8e6b8639530202cb5dc4a83f3c7ad6e570f Mon Sep 17 00:00:00 2001 From: "William F. Broderick" Date: Mon, 20 Nov 2023 17:10:51 -0500 Subject: [PATCH] swaps inheritance around changes where the public method evaluate is defined, so that the user-facing docstring for the simple basis objects (i.e., not Additive or Multiplicative) is clearer --- src/nemos/basis.py | 238 ++++++++++++++++++++++++++++++++------------- 1 file changed, 171 insertions(+), 67 deletions(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index ef36a28b..44249cd3 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -52,19 +52,19 @@ def __init__(self, n_basis_funcs: int) -> None: self._check_n_basis_min() @abc.abstractmethod - def _evaluate(self, *xi: NDArray) -> NDArray: + def evaluate(self, *xi: NDArray) -> NDArray: """ - Evaluate the basis set at the given samples x1,...,xn using the subclass-specific "_evaluate" method. + Evaluate the basis set at the given samples x1,...,xn using the subclass-specific "evaluate" method. Parameters ---------- - *xi: (number of samples, ) + *xi: (n_samples,) The input samples xi[0],...,xi[n] . """ pass @staticmethod - def _get_samples(*n_samples: int) -> Generator[NDArray, ...]: + def _get_samples(*n_samples: int): """Get equi-spaced samples for all the input dimensions. This will be used to evaluate the basis on a grid of @@ -82,35 +82,30 @@ def _get_samples(*n_samples: int) -> Generator[NDArray, ...]: """ return (np.linspace(0, 1, n_samples[k]) for k in range(len(n_samples))) - def evaluate(self, *xi: ArrayLike) -> NDArray: - """ - Evaluate the basis set at the given samples x[0],...,x[n] using the subclass-specific "_evaluate" method. + def _check_evaluate_input(self, *xi: ArrayLike) -> Tuple[NDArray]: + """Check evaluate input. Parameters ---------- xi[0],...,xi[n] : The input samples, each with shape (number of samples, ). - Returns - ------- - : - The generated basis functions. - Raises ------ ValueError - If the time point number is inconsistent between inputs. - If the number of inputs doesn't match what the Basis object requires. - At least one of the samples is empty. - """ - # check that the input is array-like - if any( - not isinstance(x, (list, tuple, np.ndarray, jax.numpy.ndarray)) for x in xi - ): - raise TypeError("Input samples must be array-like!") - # convert to numpy.array of floats - xi = tuple(np.asarray(x, dtype=float) for x in xi) + """ + # check that the input is array-like (i.e., whether we can cast it to + # numeric arrays) + try: + # make sure array is at least 1d (so that we succeed when only + # passed a scalar) + xi = tuple(np.atleast_1d(np.asarray(x, dtype=float)) for x in xi) + except TypeError: + raise TypeError("Input samples must be array-like of floats!") # check for non-empty samples if self._has_zero_samples(tuple(len(x) for x in xi)): @@ -119,10 +114,7 @@ def evaluate(self, *xi: ArrayLike) -> NDArray: # checks on input and outputs self._check_samples_consistency(*xi) self._check_input_dimensionality(xi) - - eval_basis = self._evaluate(*xi) - - return eval_basis + return xi def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: """Evaluate the basis set on a grid of equi-spaced sample points. @@ -134,7 +126,8 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: Parameters ---------- n_samples[0],...,n_samples[n] - The number of samples in each axis of the grid. + The number of samples in each axis of the grid. The length of + n_samples must equal the number of combined bases. Returns ------- @@ -335,24 +328,27 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None: def _check_n_basis_min(self) -> None: pass - def _evaluate(self, *xi: NDArray) -> NDArray: + def evaluate(self, *xi: NDArray) -> NDArray: """ Evaluate the basis at the input samples. Parameters ---------- - xi[0], ..., xi[n] : (number of samples, ) - Tuple of input samples. + xi[0], ..., xi[n] : (n_samples,) + Tuple of input samples, each with the same number of samples. The + number of input arrays must equal the number of combined bases. Returns ------- : The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) + """ + xi = self._check_evaluate_input(*xi) return np.hstack( ( - self._basis1._evaluate(*xi[: self._basis1._n_input_dimensionality]), - self._basis2._evaluate(*xi[self._basis1._n_input_dimensionality :]), + self._basis1.evaluate(*xi[: self._basis1._n_input_dimensionality]), + self._basis2.evaluate(*xi[self._basis1._n_input_dimensionality :]), ) ) @@ -388,24 +384,26 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None: def _check_n_basis_min(self) -> None: pass - def _evaluate(self, *xi: NDArray) -> NDArray: + def evaluate(self, *xi: NDArray) -> NDArray: """ Evaluate the basis at the input samples. Parameters ---------- - xi[0], ..., xi[n] : (number of samples, ) - Tuple of input samples. + xi[0], ..., xi[n] : (n_samples,) + Tuple of input samples, each with the same number of samples. The + number of input arrays must equal the number of combined bases. Returns ------- : The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) """ + xi = self._check_evaluate_input(*xi) return np.array( row_wise_kron( - self._basis1._evaluate(*xi[: self._basis1._n_input_dimensionality]), - self._basis2._evaluate(*xi[self._basis1._n_input_dimensionality :]), + self._basis1.evaluate(*xi[: self._basis1._n_input_dimensionality]), + self._basis2.evaluate(*xi[self._basis1._n_input_dimensionality :]), transpose=False, ) ) @@ -448,7 +446,7 @@ def _generate_knots( Parameters ---------- - sample_pts : (number of samples, ) + sample_pts : (n_samples,) The sample points. perc_low The low percentile value, between [0,1). @@ -531,14 +529,13 @@ class MSplineBasis(SplineBasis): def __init__(self, n_basis_funcs: int, order: int = 2) -> None: super().__init__(n_basis_funcs, order) - def _evaluate(self, sample_pts: NDArray) -> NDArray: + def evaluate(self, sample_pts: NDArray) -> NDArray: """Generate basis functions with given spacing. Parameters ---------- sample_pts : - Spacing for basis functions, holding elements on the interval [min(sample_pts), - max(sample_pts)], shape (number of samples, ) + Spacing for basis functions, shape (n_samples,) Returns ------- @@ -546,6 +543,7 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: Evaluated spline basis functions, shape (n_samples, n_basis_funcs). """ + sample_pts, = self._check_evaluate_input(sample_pts) # add knots if not passed knot_locs = self._generate_knots( sample_pts, perc_low=0.0, perc_high=1.0, is_cyclic=False @@ -559,6 +557,24 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: axis=1, ) + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """Evaluate the M-spline basis set on a grid of equi-spaced sample points. + + Parameters + ---------- + n_samples : + The number of samples. + + Returns + ------- + X : + Array of shape (n_samples,) containing the equi-spaced sample + points where we've evaluated the basis. + basis_funcs : + Raised cosine basis functions, shape (n_samples, n_basis_funcs) + + """ + return super().evaluate_on_grid(n_samples) class BSplineBasis(SplineBasis): """ @@ -589,30 +605,31 @@ class BSplineBasis(SplineBasis): def __init__(self, n_basis_funcs: int, order: int = 2): super().__init__(n_basis_funcs, order=order) - def _evaluate(self, sample_pts: NDArray) -> NDArray: + def evaluate(self, sample_pts: NDArray) -> NDArray: """ Evaluate the B-spline basis functions with given sample points. Parameters ---------- sample_pts : - The sample points at which the B-spline is evaluated. + The sample points at which the B-spline is evaluated, shape (n_samples,). Returns ------- - NDArray + basis_funcs : The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) Raises ------ AssertionError - If the sample points are not within the B-spline knots range unless `outer_ok=True`. + If the sample points are not within the B-spline knots. Notes ----- The evaluation is performed by looping over each element and using `splev` from SciPy to compute the basis values. """ + sample_pts, = self._check_evaluate_input(sample_pts) # add knots knot_locs = self._generate_knots(sample_pts, 0.0, 1.0) @@ -623,6 +640,29 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: return basis_eval + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """Evaluate the B-spline basis set on a grid of equi-spaced sample points. + + Parameters + ---------- + n_samples : + The number of samples. + + Returns + ------- + X : + Array of shape (n_samples,) containing the equi-spaced sample + points where we've evaluated the basis. + basis_funcs : + Raised cosine basis functions, shape (n_samples, n_basis_funcs) + + Notes + ----- + The evaluation is performed by looping over each element and using `splev` from + SciPy to compute the basis values. + """ + return super().evaluate_on_grid(n_samples) + class CyclicBSplineBasis(SplineBasis): """ @@ -653,30 +693,27 @@ def __init__(self, n_basis_funcs: int, order: int = 2): f"order {self.order} specified instead!" ) - def _evaluate(self, sample_pts: NDArray) -> NDArray: - """ - Evaluate the B-spline basis functions with given sample points. + def evaluate(self, sample_pts: NDArray) -> NDArray: + """Evaluate the Cyclic B-spline basis functions with given sample points. Parameters ---------- sample_pts : - The sample points at which the B-spline is evaluated. Must be a tuple of length 1. + The sample points at which the cyclic B-spline is evaluated, shape + (n_samples,). Returns ------- - NDArray + basis_funcs : The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) - Raises - ------ - AssertionError - If the sample points are not within the B-spline knots range unless `outer_ok=True`. - Notes ----- The evaluation is performed by looping over each element and using `splev` from SciPy to compute the basis values. + """ + sample_pts, = self._check_evaluate_input(sample_pts) knot_locs = self._generate_knots(sample_pts, 0.0, 1.0, is_cyclic=True) @@ -710,6 +747,28 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: return basis_eval + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """Evaluate the Cyclic B-spline basis set on a grid of equi-spaced sample points. + + Parameters + ---------- + n_samples : + The number of samples. + + Returns + ------- + X : + Array of shape (n_samples,) containing the equi-spaced sample + points where we've evaluated the basis. + basis_funcs : + Raised cosine basis functions, shape (n_samples, n_basis_funcs) + + Notes + ----- + The evaluation is performed by looping over each element and using `splev` from + SciPy to compute the basis values. + """ + return super().evaluate_on_grid(n_samples) class RaisedCosineBasis(Basis, abc.ABC): def __init__(self, n_basis_funcs: int) -> None: @@ -724,21 +783,21 @@ def _transform_samples(self, sample_pts: NDArray) -> NDArray: Parameters ---------- sample_pts : - The sample points to be transformed, shape (number of samples, ). + The sample points to be transformed, shape (n_samples,). """ pass - def _evaluate(self, sample_pts: NDArray) -> NDArray: + def evaluate(self, sample_pts: NDArray) -> NDArray: """Generate basis functions with given samples. Parameters ---------- sample_pts : (number of samples,) - Spacing for basis functions, holding elements on interval [0, - 1). A good default is - ``nmo.sample_points.raised_cosine_log`` for log spacing (as used in - [2]_) or ``nmo.sample_points.raised_cosine_linear`` for linear - spacing. + Spacing for basis functions, holding elements on interval [0, 1). A + good default is ``nmo.sample_points.raised_cosine_log`` for log + spacing (as used in [2]_) or + ``nmo.sample_points.raised_cosine_linear`` for linear spacing. + Shape (n_samples,). Returns ------- @@ -749,7 +808,9 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: ------ ValueError If the sample provided do not lie in [0,1]. + """ + sample_pts, = self._check_evaluate_input(sample_pts) if any(sample_pts < 0) or any(sample_pts > 1): raise ValueError("Sample points for RaisedCosine basis must lie in [0,1]!") @@ -764,6 +825,25 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: return basis_funcs + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """Evaluate the basis set on a grid of equi-spaced sample points. + + Parameters + ---------- + n_samples : + The number of samples. + + Returns + ------- + X : + Array of shape (n_samples,) containing the equi-spaced sample + points where we've evaluated the basis. + basis_funcs : + Raised cosine basis functions, shape (n_samples, n_basis_funcs) + + """ + return super().evaluate_on_grid(n_samples) + class RaisedCosineBasisLinear(RaisedCosineBasis): """Linearly-spaced raised cosine basis functions used by Pillow et al. [2]_. @@ -794,7 +874,7 @@ def _transform_samples(self, sample_pts: NDArray) -> NDArray: Parameters ---------- sample_pts : - The sample points used for evaluating the splines, shape (number of samples, ) + The sample points used for evaluating the splines, shape (n_samples,) Returns ------- @@ -851,7 +931,7 @@ def _transform_samples(self, sample_pts: NDArray) -> NDArray: Parameters ---------- sample_pts : - The sample points used for evaluating the splines, shape (number of samples, ). + The sample points used for evaluating the splines, shape (n_samples,). Returns ------- @@ -984,20 +1064,23 @@ def _check_sample_size(self, *sample_pts: NDArray): f"but only {sample_pts[0].size} samples provided!" ) - def _evaluate(self, sample_pts: NDArray) -> NDArray: + def evaluate(self, sample_pts: NDArray) -> NDArray: """Generate basis functions with given spacing. Parameters ---------- sample_pts - Spacing for basis functions, holding elements on the interval [0, inf), shape (n_pts,). + Spacing for basis functions, holding elements on the interval [0, + inf), shape (n_samples,). Returns ------- basis_funcs - Evaluated exponentially decaying basis functions, - numerically orthogonalized, shape (number of basis, number of samples). + Evaluated exponentially decaying basis functions, numerically + orthogonalized, shape (n_samples, n_basis_funcs) + """ + sample_pts, = self._check_evaluate_input(sample_pts) self._check_sample_range(sample_pts) self._check_sample_size(sample_pts) # because of how scipy.linalg.orth works, have to create a matrix of @@ -1008,6 +1091,26 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: np.stack([np.exp(-lam * sample_pts) for lam in self._decay_rates], axis=1) ) + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """Evaluate the basis set on a grid of equi-spaced sample points. + + Parameters + ---------- + n_samples : + The number of samples. + + Returns + ------- + X : + Array of shape (n_samples,) containing the equi-spaced sample + points where we've evaluated the basis. + basis_funcs : + Evaluated exponentially decaying basis functions, numerically + orthogonalized, shape (n_samples, n_basis_funcs) + + """ + return super().evaluate_on_grid(n_samples) + def mspline(x: NDArray, k: int, i: int, T: NDArray): """Compute M-spline basis function. @@ -1066,7 +1169,8 @@ def bspline( Parameters ---------- sample_pts : - An array containing sample points for which B-spline basis needs to be evaluated. + An array containing sample points for which B-spline basis needs to be evaluated, + shape (n_samples,) knots : An array containing knots for the B-spline basis. The knots are sorted in ascending order. order :