diff --git a/docs/developers_notes/01-basis_module.md b/docs/developers_notes/01-basis_module.md index ebb2ace5..c0bdf6f8 100644 --- a/docs/developers_notes/01-basis_module.md +++ b/docs/developers_notes/01-basis_module.md @@ -19,9 +19,7 @@ Abstract Class Basis │ │ │ └─ Concrete Subclass CyclicBSplineBasis │ -├─ Abstract Subclass RaisedCosineBasis -│ │ -│ ├─ Concrete Subclass RaisedCosineBasisLinear +├─ Concrete Subclass RaisedCosineBasisLinear │ │ │ └─ Concrete Subclass RaisedCosineBasisLog │ diff --git a/docs/examples/plot_1D_basis_function.py b/docs/examples/plot_1D_basis_function.py index b6b8bf94..37ce1e1a 100644 --- a/docs/examples/plot_1D_basis_function.py +++ b/docs/examples/plot_1D_basis_function.py @@ -69,11 +69,11 @@ # evaluate a log-spaced cosine raised function basis. # Instantiate the basis noting that the `RaisedCosineBasisLog` does not require an `order` parameter -raised_cosine_log = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=10) +raised_cosine_log = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=10, width=1.5, time_scaling=50) # Evaluate the raised cosine basis at the equi-spaced sample points # (same method in all Basis elements) -samples, eval_basis = raised_cosine_log.evaluate_on_grid(1000) +samples, eval_basis = raised_cosine_log.evaluate_on_grid(100) # Plot the evaluated log-spaced raised cosine basis plt.figure() diff --git a/src/nemos/__init__.py b/src/nemos/__init__.py index 28836560..8bcee80f 100644 --- a/src/nemos/__init__.py +++ b/src/nemos/__init__.py @@ -1,11 +1,3 @@ #!/usr/bin/env python3 -from . import ( - basis, - exceptions, - glm, - observation_models, - regularizer, - simulation, - utils, -) +from . import basis, exceptions, glm, observation_models, regularizer, simulation, utils diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 428ac417..44150e68 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -51,7 +51,7 @@ def __init__(self, n_basis_funcs: int) -> None: self._check_n_basis_min() @abc.abstractmethod - def evaluate(self, *xi: NDArray) -> NDArray: + def evaluate(self, *xi: ArrayLike) -> NDArray: """ Evaluate the basis set at the given samples x1,...,xn using the subclass-specific "evaluate" method. @@ -327,7 +327,7 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None: def _check_n_basis_min(self) -> None: pass - def evaluate(self, *xi: NDArray) -> NDArray: + def evaluate(self, *xi: ArrayLike) -> NDArray: """ Evaluate the basis at the input samples. @@ -383,7 +383,7 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None: def _check_n_basis_min(self) -> None: pass - def evaluate(self, *xi: NDArray) -> NDArray: + def evaluate(self, *xi: ArrayLike) -> NDArray: """ Evaluate the basis at the input samples. @@ -505,7 +505,7 @@ def _check_n_basis_min(self) -> None: class MSplineBasis(SplineBasis): - """M-spline 1-dimensional basis functions. + """M-spline[$^1$](#references) 1-dimensional basis functions. Parameters ---------- @@ -519,15 +519,14 @@ class MSplineBasis(SplineBasis): References ---------- - [^1]: - Ramsay, J. O. (1988). Monotone regression splines in action. + 1. Ramsay, J. O. (1988). Monotone regression splines in action. Statistical science, 3(4), 425-441. """ def __init__(self, n_basis_funcs: int, order: int = 2) -> None: super().__init__(n_basis_funcs, order) - def evaluate(self, sample_pts: NDArray) -> NDArray: + def evaluate(self, sample_pts: ArrayLike) -> NDArray: """Generate basis functions with given spacing. Parameters @@ -577,7 +576,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: class BSplineBasis(SplineBasis): """ - B-spline 1-dimensional basis functions. + B-spline[$^1$](#references) 1-dimensional basis functions. Parameters ---------- @@ -596,8 +595,7 @@ class BSplineBasis(SplineBasis): References ---------- - [^2]: - Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques. + 1. Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques. Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5 """ @@ -605,7 +603,7 @@ class BSplineBasis(SplineBasis): def __init__(self, n_basis_funcs: int, order: int = 2): super().__init__(n_basis_funcs, order=order) - def evaluate(self, sample_pts: NDArray) -> NDArray: + def evaluate(self, sample_pts: ArrayLike) -> NDArray: """ Evaluate the B-spline basis functions with given sample points. @@ -693,7 +691,7 @@ def __init__(self, n_basis_funcs: int, order: int = 2): f"order {self.order} specified instead!" ) - def evaluate(self, sample_pts: NDArray) -> NDArray: + def evaluate(self, sample_pts: ArrayLike) -> NDArray: """Evaluate the Cyclic B-spline basis functions with given sample points. Parameters @@ -771,34 +769,68 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return super().evaluate_on_grid(n_samples) -class RaisedCosineBasis(Basis, abc.ABC): - def __init__(self, n_basis_funcs: int) -> None: +class RaisedCosineBasisLinear(Basis): + """Represent linearly-spaced raised cosine basis functions. + + This implementation is based on the cosine bumps used by Pillow et al.[$^1$](#references) + to uniformly tile the internal points of the domain. + + Parameters + ---------- + n_basis_funcs : + The number of basis functions. + width : + Width of the raised cosine. By default, it's set to 2.0. + + References + ---------- + 1. Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., + C. E. (2005). Prediction and decoding of retinal ganglion cell responses + with a probabilistic spiking model. Journal of Neuroscience, 25(47), + 11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005 + """ + + def __init__(self, n_basis_funcs: int, width: float = 2.0) -> None: super().__init__(n_basis_funcs) self._n_input_dimensionality = 1 + self._check_width(width) + self._width = width - @abc.abstractmethod - def _transform_samples(self, sample_pts: NDArray) -> NDArray: - """ - Abstract method for transforming sample points. + @property + def width(self): + """Return width of the raised cosine.""" + return self._width + + @staticmethod + def _check_width(width: float): + """Validate the width value. Parameters ---------- - sample_pts : - The sample points to be transformed, shape (n_samples,). + width : + The width value to validate. + + Raises + ------ + ValueError + If width <= 1 or 2*width is not a positive integer. Values that do not match + this constraint will result in: + - No overlap between bumps (width < 1). + - Oscillatory behavior when summing the basis elements (2*width not integer). """ - pass + if width <= 1 or (not np.isclose(width * 2, round(2 * width))): + raise ValueError( + f"Invalid raised cosine width. " + f"2*width must be a positive integer, 2*width = {2 * width} instead!" + ) - def evaluate(self, sample_pts: NDArray) -> NDArray: + def evaluate(self, sample_pts: ArrayLike) -> NDArray: """Generate basis functions with given samples. Parameters ---------- - sample_pts : (number of samples,) - Spacing for basis functions, holding elements on interval [0, 1). A - good default is ``nmo.sample_points.raised_cosine_log`` for log - spacing (as used in [2]_) or - ``nmo.sample_points.raised_cosine_linear`` for linear spacing. - Shape (n_samples,). + sample_pts : + Spacing for basis functions, holding elements on interval [0, 1], Shape (number of samples, ). Returns ------- @@ -815,17 +847,34 @@ def evaluate(self, sample_pts: NDArray) -> NDArray: if any(sample_pts < 0) or any(sample_pts > 1): raise ValueError("Sample points for RaisedCosine basis must lie in [0,1]!") - # transform to the proper domain - transform_sample_pts = self._transform_samples(sample_pts) - - shifted_sample_pts = ( - transform_sample_pts[:, None] - - (np.pi * np.arange(self.n_basis_funcs))[None, :] + peaks = self._compute_peaks() + delta = peaks[1] - peaks[0] + # generate a set of shifted cosines, and constrain them to be non-zero + # over a single period, then enforce the codomain to be [0,1], by adding 1 + # and then multiply by 0.5 + basis_funcs = 0.5 * ( + np.cos( + np.clip( + np.pi * (sample_pts[:, None] - peaks[None]) / (delta * self.width), + -np.pi, + np.pi, + ) + ) + + 1 ) - basis_funcs = 0.5 * (np.cos(np.clip(shifted_sample_pts, -np.pi, np.pi)) + 1) return basis_funcs + def _compute_peaks(self): + """ + Compute the location of raised cosine peaks. + + Returns + ------- + Peak locations of each basis element. + """ + return np.linspace(0, 1, self.n_basis_funcs) + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """Evaluate the basis set on a grid of equi-spaced sample points. @@ -845,127 +894,140 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples) - -class RaisedCosineBasisLinear(RaisedCosineBasis): - """Linearly-spaced raised cosine basis functions used by Pillow et al. - - These are "cosine bumps" that uniformly tile the space. - - - Parameters - ---------- - n_basis_funcs - Number of basis functions. - - References - ---------- - [^3]: - Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., - C. E. (2005). Prediction and decoding of retinal ganglion cell responses - with a probabilistic spiking model. Journal of Neuroscience, 25(47), - 11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005 - - """ - - def __init__(self, n_basis_funcs: int) -> None: - super().__init__(n_basis_funcs) - - def _transform_samples(self, sample_pts: NDArray) -> NDArray: - """ - Linearly map the samples from [0,1] to the the [0, (n_basis_funcs - 1) * pi]. - - Parameters - ---------- - sample_pts : - The sample points used for evaluating the splines, shape (n_samples,) - - Returns - ------- - : - A transformed version of the sample points that matches the Raised Cosine basis domain, - shape (number of samples, ). - """ - return sample_pts * np.pi * (self.n_basis_funcs - 1) - def _check_n_basis_min(self) -> None: """Check that the user required enough basis elements. - Check that the number of basis is at least 1. + Check that the number of basis is at least 2. Raises ------ ValueError - If an insufficient number of basis element is requested for the basis type + If n_basis_funcs < 2. """ - if self.n_basis_funcs < 1: + if self.n_basis_funcs < 2: raise ValueError( - f"Object class {self.__class__.__name__} requires >= 1 basis elements. " + f"Object class {self.__class__.__name__} requires >= 2 basis elements. " f"{self.n_basis_funcs} basis elements specified instead" ) -class RaisedCosineBasisLog(RaisedCosineBasis): - """Log-spaced raised cosine basis functions used by Pillow et al. [2]_. +class RaisedCosineBasisLog(RaisedCosineBasisLinear): + """Represent log-spaced raised cosine basis functions. - These are "cosine bumps" that uniformly tile the space. + Similar to `RaisedCosineBasisLinear` but the basis functions are log-spaced. + This implementation is based on the cosine bumps used by Pillow et al.[$^1$](#references) + to uniformly tile the internal points of the domain. Parameters ---------- - n_basis_funcs - Number of basis functions. + n_basis_funcs : + The number of basis functions. + width : + Width of the raised cosine. By default, it's set to 2.0. + enforce_decay_to_zero: + If set to True, the algorithm first constructs a basis with `n_basis_funcs + ceil(width)` elements + and subsequently trims off the extra basis elements. This ensures that the final basis element + decays to 0. References ---------- - .. [2] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., + 1. Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., C. E. (2005). Prediction and decoding of retinal ganglion cell responses with a probabilistic spiking model. Journal of Neuroscience, 25(47), 11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005 - """ - def __init__(self, n_basis_funcs: int) -> None: - super().__init__(n_basis_funcs) + def __init__( + self, + n_basis_funcs: int, + width: float = 2.0, + time_scaling: float = None, + enforce_decay_to_zero: bool = True, + ) -> None: + super().__init__(n_basis_funcs, width=width) + self.enforce_decay_to_zero = enforce_decay_to_zero + if time_scaling is None: + self._time_scaling = 50.0 + else: + self._check_time_scaling(time_scaling) + self._time_scaling = time_scaling + + @property + def time_scaling(self): + return self._time_scaling - def _transform_samples(self, sample_pts: NDArray) -> NDArray: - """Map the sample domain to log-space. + @staticmethod + def _check_time_scaling(time_scaling): + if time_scaling <= 0: + raise ValueError( + f"Only strictly positive time_scaling are allowed, {time_scaling} provided instead." + ) - Map the equi-spaced samples from [0,1] to log equi-spaced samples [0, (n_basis_funcs - 1) * pi]. + def _transform_samples(self, sample_pts: NDArray) -> NDArray: + """ + Map the sample domain to log-space. Parameters ---------- sample_pts : - The sample points used for evaluating the splines, shape (n_samples,). + Sample points used for evaluating the splines, + shape (n_samples, ). Returns ------- - : - A transformed version of the sample points that matches the Raised Cosine basis domain, - shape (n_sample_points, ). + Transformed version of the sample points that matches the Raised Cosine basis domain, + shape (n_samples, ). """ - return ( - np.power( - 10, - -(np.log10((self.n_basis_funcs - 1) * np.pi) + 1) * sample_pts - + np.log10((self.n_basis_funcs - 1) * np.pi), - ) - - 0.1 + # This log-stretching of the sample axis has the following effect: + # - as the time_scaling tends to 0, the points will be linearly spaced across the whole domain. + # - as the time_scaling tends to inf, basis will be small and dense around 0 and + # progressively larger and less dense towards 1. + log_spaced_pts = np.log(self.time_scaling * sample_pts + 1) / np.log( + self.time_scaling + 1 ) + return log_spaced_pts - def _check_n_basis_min(self) -> None: - """Check that the user required enough basis elements. + def _compute_peaks(self): + """ + Peak location of each log-spaced cosine basis element - Checks that the number of basis is at least 2. + Compute the peak location for the log-spaced raised cosine basis. + Enforcing that the last basis decays to zero is equivalent to + setting the last peak to a value smaller than 1. + + Returns + ------- + Peak locations of each basis element. + + """ + if self.enforce_decay_to_zero: + # compute the last peak location such that the last + # basis element decays to zero at the last sample. + last_peak = 1 - self.width / (self.n_basis_funcs + self.width - 1) + else: + last_peak = 1 + return np.linspace(0, last_peak, self.n_basis_funcs) + + def evaluate(self, sample_pts: ArrayLike) -> NDArray: + """Generate log-spaced raised cosine basis with given samples. + + Parameters + ---------- + sample_pts : + Spacing for basis functions, holding elements on interval [0, 1]. + + Returns + ------- + basis_funcs : + Log-raised cosine basis functions, shape (n_samples, n_basis_funcs). Raises ------ ValueError - If an insufficient number of basis element is requested for the basis type + If the sample provided do not lie in [0,1]. """ - if self.n_basis_funcs < 2: - raise ValueError( - f"Object class {self.__class__.__name__} requires >= 2 basis elements. " - f"{self.n_basis_funcs} basis elements specified instead" - ) + (sample_pts,) = self._check_evaluate_input(sample_pts) + return super().evaluate(self._transform_samples(sample_pts)) class OrthExponentialBasis(Basis): diff --git a/tests/test_basis.py b/tests/test_basis.py index 59f3b2f1..96760840 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -156,9 +156,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): basis_obj = self.cls(n_basis_funcs=5) inputs = [np.linspace(0, 1, 20)] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, match="evaluate\(\) missing 1 required positional argument" + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: @@ -204,14 +209,84 @@ def test_evaluate_on_grid_input_number(self, n_input): basis_obj = self.cls(n_basis_funcs=5) inputs = [10] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) missing 1 required positional argument", + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: basis_obj.evaluate_on_grid(*inputs) + @pytest.mark.parametrize( + "width ,expectation", + [ + (-1, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + (0, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + (0.5, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + (1, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + (1.5, does_not_raise()), + (2, does_not_raise()), + (2.1, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + ], + ) + def test_width_values(self, width, expectation): + """Test allowable widths: integer multiple of 1/2, greater than 1.""" + with expectation: + self.cls(n_basis_funcs=5, width=width) + + @pytest.mark.parametrize("width", [1.5, 2, 2.5]) + def test_decay_to_zero_basis_number_match(self, width): + """Test that the number of basis is preserved.""" + n_basis_funcs = 10 + _, ev = self.cls( + n_basis_funcs=n_basis_funcs, width=width, enforce_decay_to_zero=True + ).evaluate_on_grid(2) + assert ev.shape[1] == n_basis_funcs, ( + "Basis function number mismatch. " + f"Expected {n_basis_funcs}, got {ev.shape[1]} instead!" + ) + + @pytest.mark.parametrize( + "time_scaling ,expectation", + [ + (-1, pytest.raises(ValueError, match="Only strictly positive time_scaling are allowed")), + (0, pytest.raises(ValueError, match="Only strictly positive time_scaling are allowed")), + (0.1, does_not_raise()), + (10, does_not_raise()), + ], + ) + def test_time_scaling_values(self, time_scaling, expectation): + """Test that only positive time_scaling are allowed.""" + with expectation: + self.cls(n_basis_funcs=5, time_scaling=time_scaling) + + def test_time_scaling_property(self): + """Test that larger time_scaling results in larger departures from linearity.""" + time_scaling = [0.1, 10, 100] + n_basis_funcs = 5 + _, lin_ev = basis.RaisedCosineBasisLinear(n_basis_funcs).evaluate_on_grid(100) + corr = np.zeros(len(time_scaling)) + for idx, ts in enumerate(time_scaling): + # set default decay to zero to get comparable basis + basis_log = self.cls( + n_basis_funcs=n_basis_funcs, + time_scaling=ts, + enforce_decay_to_zero=False, + ) + _, log_ev = basis_log.evaluate_on_grid(100) + # compute the correlation + corr[idx] = (lin_ev.flatten() @ log_ev.flatten()) / ( + np.linalg.norm(lin_ev.flatten()) * np.linalg.norm(log_ev.flatten()) + ) + # check that the correlation decreases as time_scale increases + assert np.all(np.diff(corr) < 0), "As time scales increases, deviation from linearity should increase!" + class TestRaisedCosineLinearBasis(BasisFuncsTesting): cls = basis.RaisedCosineBasisLinear @@ -238,7 +313,7 @@ def test_evaluate_input(self, eval_input): @pytest.mark.parametrize( "args, sample_size", - [[{"n_basis_funcs": n_basis}, 100] for n_basis in [1, 2, 10, 100]], + [[{"n_basis_funcs": n_basis}, 100] for n_basis in [2, 10, 100]], ) def test_evaluate_returns_expected_number_of_basis(self, args, sample_size): """ @@ -276,12 +351,12 @@ def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs): """ Verifies that the minimum number of basis functions required (i.e., 1) is enforced. """ - raise_exception = n_basis_funcs < 1 + raise_exception = n_basis_funcs < 2 if raise_exception: with pytest.raises( ValueError, match=f"Object class {self.cls.__name__} " - r"requires >= 1 basis elements\.", + r"requires >= 2 basis elements\.", ): self.cls(n_basis_funcs=n_basis_funcs) else: @@ -312,9 +387,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): basis_obj = self.cls(n_basis_funcs=5) inputs = [np.linspace(0, 1, 20)] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, match="evaluate\(\) missing 1 required positional argument" + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: @@ -360,14 +440,37 @@ def test_evaluate_on_grid_input_number(self, n_input): basis_obj = self.cls(n_basis_funcs=5) inputs = [10] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) missing 1 required positional argument", + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: basis_obj.evaluate_on_grid(*inputs) + @pytest.mark.parametrize( + "width ,expectation", + [ + (-1, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + (0, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + (0.5, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + (1, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + (1.5, does_not_raise()), + (2, does_not_raise()), + (2.1, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + ], + ) + def test_width_values(self, width, expectation): + """Test allowable widths: integer multiple of 1/2, greater than 1.""" + with expectation: + self.cls(n_basis_funcs=5, width=width) + class TestMSplineBasis(BasisFuncsTesting): cls = basis.MSplineBasis @@ -466,9 +569,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): basis_obj = self.cls(n_basis_funcs=5, order=3) inputs = [np.linspace(0, 1, 20)] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, match="evaluate\(\) missing 1 required positional argument" + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: @@ -514,9 +622,15 @@ def test_evaluate_on_grid_input_number(self, n_input): basis_obj = self.cls(n_basis_funcs=5, order=3) inputs = [10] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) missing 1 required positional argument", + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: @@ -538,7 +652,8 @@ def test_non_empty_samples(self, samples): self.cls(5, decay_rates=np.arange(1, 6)).evaluate(samples) @pytest.mark.parametrize( - "eval_input", [0, [0] * 6, (0,) * 6, np.array([0] * 6), jax.numpy.array([0] * 6)] + "eval_input", + [0, [0] * 6, (0,) * 6, np.array([0] * 6), jax.numpy.array([0] * 6)], ) def test_evaluate_input(self, eval_input): """ @@ -547,7 +662,10 @@ def test_evaluate_input(self, eval_input): basis_obj = self.cls(n_basis_funcs=5, decay_rates=np.arange(1, 6)) if isinstance(eval_input, int): # OrthExponentialBasis is special -- cannot accept int input - with pytest.raises(ValueError, match="OrthExponentialBasis requires at least as many samples"): + with pytest.raises( + ValueError, + match="OrthExponentialBasis requires at least as many samples", + ): basis_obj.evaluate(eval_input) else: basis_obj.evaluate(eval_input) @@ -626,9 +744,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): basis_obj = self.cls(n_basis_funcs=5, decay_rates=np.arange(1, 6)) inputs = [np.linspace(0, 1, 20)] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, match="evaluate\(\) missing 1 required positional argument" + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: @@ -673,9 +796,15 @@ def test_evaluate_on_grid_input_number(self, n_input): basis_obj = self.cls(n_basis_funcs=5, decay_rates=np.arange(1, 6)) inputs = [10] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) missing 1 required positional argument", + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: @@ -829,9 +958,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): basis_obj = self.cls(n_basis_funcs=5, order=3) inputs = [np.linspace(0, 1, 20)] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, match="evaluate\(\) missing 1 required positional argument" + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: @@ -881,9 +1015,15 @@ def test_evaluate_on_grid_input_number(self, n_input): basis_obj = self.cls(n_basis_funcs=5, order=3) inputs = [10] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) missing 1 required positional argument", + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: @@ -1020,9 +1160,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): basis_obj = self.cls(n_basis_funcs=5, order=3) inputs = [np.linspace(0, 1, 20)] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, match="evaluate\(\) missing 1 required positional argument" + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: @@ -1072,9 +1217,15 @@ def test_evaluate_on_grid_input_number(self, n_input): basis_obj = self.cls(n_basis_funcs=5, order=3) inputs = [10] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) missing 1 required positional argument", + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: @@ -1238,10 +1389,14 @@ def test_number_of_required_inputs_evaluate( basis_a_obj = self.instantiate_basis(n_basis_a, basis_a) basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) basis_obj = basis_a_obj + basis_b_obj - required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality + required_dim = ( + basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality + ) inputs = [np.linspace(0, 1, 20)] * n_input if n_input != required_dim: - expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.") + expectation = pytest.raises( + TypeError, match="Input dimensionality mismatch." + ) else: expectation = does_not_raise() with expectation: @@ -1320,9 +1475,13 @@ def test_evaluate_on_grid_input_number( basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) basis_obj = basis_a_obj + basis_b_obj inputs = [20] * n_input - required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality + required_dim = ( + basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality + ) if n_input != required_dim: - expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.") + expectation = pytest.raises( + TypeError, match="Input dimensionality mismatch." + ) else: expectation = does_not_raise() with expectation: @@ -1395,7 +1554,7 @@ def test_evaluate_returns_expected_number_of_basis( f"The first dimension of the evaluated basis is {eval_basis.shape[1]}", ) - @pytest.mark.parametrize("sample_size", [6, 30, 35]) + @pytest.mark.parametrize("sample_size", [12, 30, 35]) @pytest.mark.parametrize("n_basis_a", [5, 6]) @pytest.mark.parametrize("n_basis_b", [5, 6]) @pytest.mark.parametrize( @@ -1446,10 +1605,14 @@ def test_number_of_required_inputs_evaluate( basis_a_obj = self.instantiate_basis(n_basis_a, basis_a) basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) basis_obj = basis_a_obj * basis_b_obj - required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality + required_dim = ( + basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality + ) inputs = [np.linspace(0, 1, 20)] * n_input if n_input != required_dim: - expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.") + expectation = pytest.raises( + TypeError, match="Input dimensionality mismatch." + ) else: expectation = does_not_raise() with expectation: @@ -1528,9 +1691,13 @@ def test_evaluate_on_grid_input_number( basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) basis_obj = basis_a_obj * basis_b_obj inputs = [20] * n_input - required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality + required_dim = ( + basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality + ) if n_input != required_dim: - expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.") + expectation = pytest.raises( + TypeError, match="Input dimensionality mismatch." + ) else: expectation = does_not_raise() with expectation: