From 1f1e95ba994dca1478cf87d3a63042ea85adc732 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Oct 2023 12:06:41 -0400 Subject: [PATCH 01/33] added tests --- src/neurostatslib/basis.py | 244 ++++++++++++++++++++++--------------- tests/test_basis.py | 70 ++++++++++- 2 files changed, 210 insertions(+), 104 deletions(-) diff --git a/src/neurostatslib/basis.py b/src/neurostatslib/basis.py index ee7ff73e..17129bf7 100644 --- a/src/neurostatslib/basis.py +++ b/src/neurostatslib/basis.py @@ -711,34 +711,82 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: return basis_eval -class RaisedCosineBasis(Basis, abc.ABC): - def __init__(self, n_basis_funcs: int) -> None: +class RaisedCosineBasisLinear(Basis): + """Represent linearly-spaced raised cosine basis functions. + + This implementation is based on the cosine bumps used by Pillow et al. [2] + to uniformly tile the domain (if alpha = 1) or the internal points of the domain + (if alpha > 1). + + Parameters + ---------- + n_basis_funcs : + The number of basis functions. + alpha : + Width of the raised cosine. By default, it's set to 1.0. + + Attributes + ---------- + n_basis_funcs : + The number of basis functions. + alpha : + Width of the raised cosine. By default, it's set to 1.0. + + References + ---------- + [2] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., + C. E. (2005). Prediction and decoding of retinal ganglion cell responses + with a probabilistic spiking model. Journal of Neuroscience, 25(47), + 11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005 + """ + + def __init__(self, n_basis_funcs: int, alpha: float = 1.0) -> None: super().__init__(n_basis_funcs) self._n_input_dimensionality = 1 + self._check_alpha(alpha) + self._alpha = alpha - @abc.abstractmethod - def _transform_samples(self, sample_pts: NDArray) -> NDArray: - """ - Abstract method for transforming sample points. + @property + def alpha(self): + """Return width of the raised cosine.""" + return self._alpha + + @alpha.setter + def alpha(self, alpha: float): + """Check and set width of the raised cosine.""" + self._check_alpha(alpha) + self._alpha = alpha + + @staticmethod + def _check_alpha(alpha: float): + """Validate the width value. Parameters ---------- - sample_pts : - The sample points to be transformed, shape (number of samples, ). + alpha : + The width value to validate. + + Raises + ------ + ValueError + If alpha < 1 or 2*alpha is not a positive integer. Values that do not match + this constraint will result in: + - No overlap between bumps (alpha < 1). + - Oscillatory behavior when summing the basis elements (2*alpha not integer). """ - pass + if alpha < 1 or (not np.isclose(alpha * 2, round(2*alpha))): + raise ValueError( + f"Invalid raised cosine width. " + f"2*alpha must be a positive integer, 2*alpha = {2*alpha} instead!" + ) def _evaluate(self, sample_pts: NDArray) -> NDArray: """Generate basis functions with given samples. Parameters ---------- - sample_pts : (number of samples,) - Spacing for basis functions, holding elements on interval [0, - 1). A good default is - ``nsl.sample_points.raised_cosine_log`` for log spacing (as used in - [2]_) or ``nsl.sample_points.raised_cosine_linear`` for linear - spacing. + sample_pts : + Spacing for basis functions, holding elements on interval [0, 1], Shape (number of samples, ). Returns ------- @@ -753,136 +801,132 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: if any(sample_pts < 0) or any(sample_pts > 1): raise ValueError("Sample points for RaisedCosine basis must lie in [0,1]!") - # transform to the proper domain - transform_sample_pts = self._transform_samples(sample_pts) - - shifted_sample_pts = ( - transform_sample_pts[:, None] - - (np.pi * np.arange(self.n_basis_funcs))[None, :] + peaks = np.linspace(0, 1, self.n_basis_funcs) + delta = peaks[1] - peaks[0] + basis_funcs = 0.5 * ( + np.cos( + np.clip( + np.pi * (sample_pts[:, None] - peaks[None]) / (delta * self.alpha), + -np.pi, + np.pi, + ) + ) + + 1 ) - basis_funcs = 0.5 * (np.cos(np.clip(shifted_sample_pts, -np.pi, np.pi)) + 1) return basis_funcs - -class RaisedCosineBasisLinear(RaisedCosineBasis): - """Linearly-spaced raised cosine basis functions used by Pillow et al. [2]_. - - These are "cosine bumps" that uniformly tile the space. - - Parameters - ---------- - n_basis_funcs - Number of basis functions. - - References - ---------- - .. [2] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., - C. E. (2005). Prediction and decoding of retinal ganglion cell responses - with a probabilistic spiking model. Journal of Neuroscience, 25(47), - 11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005 - - """ - - def __init__(self, n_basis_funcs: int) -> None: - super().__init__(n_basis_funcs) - - def _transform_samples(self, sample_pts: NDArray) -> NDArray: - """ - Linearly map the samples from [0,1] to the the [0, (n_basis_funcs - 1) * pi]. - - Parameters - ---------- - sample_pts : - The sample points used for evaluating the splines, shape (number of samples, ) - - Returns - ------- - : - A transformed version of the sample points that matches the Raised Cosine basis domain, - shape (number of samples, ). - """ - return sample_pts * np.pi * (self.n_basis_funcs - 1) - def _check_n_basis_min(self) -> None: """Check that the user required enough basis elements. - Check that the number of basis is at least 1. + Check that the number of basis is at least 2. Raises ------ ValueError - If an insufficient number of basis element is requested for the basis type + If n_basis_funcs < 2. """ - if self.n_basis_funcs < 1: + if self.n_basis_funcs < 2: raise ValueError( - f"Object class {self.__class__.__name__} requires >= 1 basis elements. " + f"Object class {self.__class__.__name__} requires >= 2 basis elements. " f"{self.n_basis_funcs} basis elements specified instead" ) -class RaisedCosineBasisLog(RaisedCosineBasis): - """Log-spaced raised cosine basis functions used by Pillow et al. [2]_. +class RaisedCosineBasisLog(RaisedCosineBasisLinear): + """Represent log-spaced raised cosine basis functions. - These are "cosine bumps" that uniformly tile the space. + Similar to `RaisedCosineBasisLinear` but the basis functions are log-spaced. + This implementation is based on the cosine bumps used by Pillow et al. [2] + to uniformly tile the domain (if alpha = 1) or the internal points of the domain + (if alpha > 1). Parameters ---------- - n_basis_funcs - Number of basis functions. + n_basis_funcs : + The number of basis functions. + alpha : + Width of the raised cosine. By default, it's set to 1.0. + remove_last_basis: + If True, removes the last basis element so that the basis ends in zero. + + Attributes + ---------- + n_basis_funcs : + The number of basis functions. + alpha : + Width of the raised cosine. By default, it's set to 1.0. + extend_and_trim_last: + If set to True, the algorithm first constructs a basis with `n_basis_funcs + 1` elements + and subsequently trims off the last basis element. This ensures that the final basis element + concludes at a value of 0 instead of 1. References ---------- - .. [2] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., + [2] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., C. E. (2005). Prediction and decoding of retinal ganglion cell responses with a probabilistic spiking model. Journal of Neuroscience, 25(47), 11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005 - """ - def __init__(self, n_basis_funcs: int) -> None: - super().__init__(n_basis_funcs) + def __init__(self, n_basis_funcs: int, alpha: float = 1.0, extend_and_trim_last: bool = True) -> None: + super().__init__(n_basis_funcs, alpha=alpha) + self.extend_and_trim_last = extend_and_trim_last def _transform_samples(self, sample_pts: NDArray) -> NDArray: - """Map the sample domain to log-space. - - Map the equi-spaced samples from [0,1] to log equi-spaced samples [0, (n_basis_funcs - 1) * pi]. + """ + Map the sample domain to log-space. Parameters ---------- - sample_pts : - The sample points used for evaluating the splines, shape (number of samples, ). + sample_pts : NDArray + Sample points used for evaluating the splines, + shape (n_samples, ). Returns ------- - : - A transformed version of the sample points that matches the Raised Cosine basis domain, - shape (n_sample_points, ). + NDArray + Transformed version of the sample points that matches the Raised Cosine basis domain, + shape (n_samples, ). """ - return ( - np.power( - 10, - -(np.log10((self.n_basis_funcs - 1) * np.pi) + 1) * sample_pts - + np.log10((self.n_basis_funcs - 1) * np.pi), - ) - - 0.1 - ) + # if equi-spaced samples, this is equivalent to + # log_spaced_pts = np.logspace( + # np.log10((self.n_basis_funcs - 1) * np.pi), + # -1, + # sample_pts.shape[0] + # ) - 0.1 + # log_spaced_pts = log_spaced_pts / (np.pi * (self.n_basis_funcs - 1)) + base = np.pi * (self.n_basis_funcs - 1) * 10 + log_spaced_pts = base ** (-sample_pts) - 1 / base + return log_spaced_pts - def _check_n_basis_min(self) -> None: - """Check that the user required enough basis elements. + def _evaluate(self, sample_pts: NDArray) -> NDArray: + """Generate log-spaced raised cosine basis with given samples. + + Parameters + ---------- + sample_pts : + Spacing for basis functions, holding elements on interval [0, 1]. - Checks that the number of basis is at least 2. + Returns + ------- + basis_funcs : + Log-raised cosine basis functions, shape (n_samples, n_basis_funcs). Raises ------ ValueError - If an insufficient number of basis element is requested for the basis type + If the sample provided do not lie in [0,1]. """ - if self.n_basis_funcs < 2: - raise ValueError( - f"Object class {self.__class__.__name__} requires >= 2 basis elements. " - f"{self.n_basis_funcs} basis elements specified instead" - ) + if not self.extend_and_trim_last: + eval_basis = super()._evaluate(self._transform_samples(sample_pts))[:, ::-1] + else: + # temporarily add a basis element + self.n_basis_funcs += 1 + eval_basis = super()._evaluate(self._transform_samples(sample_pts))[:, ::-1] + eval_basis = eval_basis[..., :-1] + self.n_basis_funcs -= 1 + return eval_basis class OrthExponentialBasis(Basis): diff --git a/tests/test_basis.py b/tests/test_basis.py index 7f31fe2e..1a4efb57 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -188,6 +188,37 @@ def test_evaluate_on_grid_meshgrid_size(self, sample_size): grid, _ = basis_obj.evaluate_on_grid(sample_size) assert grid.shape[0] == sample_size + @pytest.mark.parametrize( + "alpha", + [-1, 0, 0.5, 1, 1.5, 1.75, 2, 2.5], + ) + def test_check_cosine_widths_of_basis(self, alpha): + """ + Verifies that the evaluate() method returns the expected number of basis functions. + """ + raise_exception = alpha not in [1, 1.5, 2, 2.5] + if raise_exception: + with pytest.raises(ValueError, match="Invalid raised cosine width"): + self.cls(n_basis_funcs=10, alpha=alpha) + else: + self.cls(n_basis_funcs=10, alpha=alpha) + + @pytest.mark.parametrize( + "alpha", + [-1, 0, 0.5, 1, 1.5, 1.75, 2, 2.5], + ) + def test_check_cosine_widths_of_basis_from_setter(self, alpha): + """ + Verifies that the evaluate() method returns the expected number of basis functions. + """ + raise_exception = alpha not in [1, 1.5, 2, 2.5] + basis_obj = self.cls(n_basis_funcs=10) + if raise_exception: + with pytest.raises(ValueError, match="Invalid raised cosine width"): + basis_obj.alpha = alpha + else: + basis_obj.alpha = alpha + @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) def test_evaluate_on_grid_basis_size(self, sample_size): """ @@ -255,7 +286,7 @@ def test_input_to_evaluate_is_arraylike(self, arraylike): @pytest.mark.parametrize( "args, sample_size", - [[{"n_basis_funcs": n_basis}, 100] for n_basis in [1, 2, 10, 100]], + [[{"n_basis_funcs": n_basis}, 100] for n_basis in [2, 10, 100]], ) def test_evaluate_returns_expected_number_of_basis(self, args, sample_size): """ @@ -271,6 +302,37 @@ def test_evaluate_returns_expected_number_of_basis(self, args, sample_size): ) return + @pytest.mark.parametrize( + "alpha", + [-1, 0, 0.5, 1, 1.5, 1.75, 2, 2.5], + ) + def test_check_cosine_widths_of_basis(self, alpha): + """ + Verifies that the evaluate() method returns the expected number of basis functions. + """ + raise_exception = alpha not in [1, 1.5, 2, 2.5] + if raise_exception: + with pytest.raises(ValueError, match="Invalid raised cosine width"): + self.cls(n_basis_funcs=10, alpha=alpha) + else: + self.cls(n_basis_funcs=10, alpha=alpha) + + @pytest.mark.parametrize( + "alpha", + [-1, 0, 0.5, 1, 1.5, 1.75, 2, 2.5], + ) + def test_check_cosine_widths_of_basis_from_setter(self, alpha): + """ + Verifies that the evaluate() method returns the expected number of basis functions. + """ + raise_exception = alpha not in [1, 1.5, 2, 2.5] + basis_obj = self.cls(n_basis_funcs=10) + if raise_exception: + with pytest.raises(ValueError, match="Invalid raised cosine width"): + basis_obj.alpha = alpha + else: + basis_obj.alpha = alpha + @pytest.mark.parametrize("sample_size", [100, 1000]) @pytest.mark.parametrize("n_basis_funcs", [2, 10, 100]) def test_sample_size_of_evaluate_matches_that_of_input( @@ -293,10 +355,10 @@ def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs): """ Verifies that the minimum number of basis functions required (i.e., 1) is enforced. """ - raise_exception = n_basis_funcs < 1 + raise_exception = n_basis_funcs < 2 if raise_exception: with pytest.raises(ValueError, match=f"Object class {self.cls.__name__} " - r"requires >= 1 basis elements\."): + r"requires >= 2 basis elements\."): self.cls(n_basis_funcs=n_basis_funcs) else: self.cls(n_basis_funcs=n_basis_funcs) @@ -1472,7 +1534,7 @@ def test_evaluate_returns_expected_number_of_basis( f"The first dimension of the evaluated basis is {eval_basis.shape[1]}", ) - @pytest.mark.parametrize("sample_size", [6, 30, 35]) + @pytest.mark.parametrize("sample_size", [12, 30, 35]) @pytest.mark.parametrize("n_basis_a", [5, 6]) @pytest.mark.parametrize("n_basis_b", [5, 6]) @pytest.mark.parametrize( From ccb1c5efe07ca427ae3195e1ab2f900479ac98fd Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Sat, 14 Oct 2023 08:39:24 -0500 Subject: [PATCH 02/33] linted --- src/neurostatslib/basis.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/neurostatslib/basis.py b/src/neurostatslib/basis.py index 17129bf7..9b5227d2 100644 --- a/src/neurostatslib/basis.py +++ b/src/neurostatslib/basis.py @@ -774,7 +774,7 @@ def _check_alpha(alpha: float): - No overlap between bumps (alpha < 1). - Oscillatory behavior when summing the basis elements (2*alpha not integer). """ - if alpha < 1 or (not np.isclose(alpha * 2, round(2*alpha))): + if alpha < 1 or (not np.isclose(alpha * 2, round(2 * alpha))): raise ValueError( f"Invalid raised cosine width. " f"2*alpha must be a positive integer, 2*alpha = {2*alpha} instead!" @@ -869,7 +869,9 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear): 11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005 """ - def __init__(self, n_basis_funcs: int, alpha: float = 1.0, extend_and_trim_last: bool = True) -> None: + def __init__( + self, n_basis_funcs: int, alpha: float = 1.0, extend_and_trim_last: bool = True + ) -> None: super().__init__(n_basis_funcs, alpha=alpha) self.extend_and_trim_last = extend_and_trim_last From 6bd116ae88c67fc6206d93ce2e6137dc8637b1ed Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Sat, 14 Oct 2023 09:48:52 -0500 Subject: [PATCH 03/33] run on mac --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4bb79939..454c1628 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,7 +13,7 @@ jobs: tox: strategy: matrix: - os: [ubuntu-latest] #[ubuntu-latest, macos-latest, windows-latest] + os: [macos-latest] #[ubuntu-latest, macos-latest, windows-latest] python-version: ['3.10'] #['3.8', '3.9', '3.10'] runs-on: ${{ matrix.os }} steps: From 5e4122a29a99eafd76dd5e7d2138cd063356fab9 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Sat, 14 Oct 2023 09:55:55 -0500 Subject: [PATCH 04/33] test on linux and mac --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 454c1628..8aadb23c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,7 +13,7 @@ jobs: tox: strategy: matrix: - os: [macos-latest] #[ubuntu-latest, macos-latest, windows-latest] + os: [ubuntu-latest, macos-latest] #[ubuntu-latest, macos-latest, windows-latest] python-version: ['3.10'] #['3.8', '3.9', '3.10'] runs-on: ${{ matrix.os }} steps: From df6925543821e494de6b47bbf508aa6bd961d63e Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 21 Nov 2023 00:18:46 -0500 Subject: [PATCH 05/33] clip -1 --- src/neurostatslib/basis.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/neurostatslib/basis.py b/src/neurostatslib/basis.py index 9b5227d2..8b6eae27 100644 --- a/src/neurostatslib/basis.py +++ b/src/neurostatslib/basis.py @@ -746,6 +746,7 @@ def __init__(self, n_basis_funcs: int, alpha: float = 1.0) -> None: self._check_alpha(alpha) self._alpha = alpha + @property def alpha(self): """Return width of the raised cosine.""" @@ -870,10 +871,15 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear): """ def __init__( - self, n_basis_funcs: int, alpha: float = 1.0, extend_and_trim_last: bool = True + self, + n_basis_funcs: int, + alpha: float = 1.0, + extend_and_trim_last: bool = True, + clip_first: bool = True ) -> None: super().__init__(n_basis_funcs, alpha=alpha) self.extend_and_trim_last = extend_and_trim_last + self._clip_first = clip_first def _transform_samples(self, sample_pts: NDArray) -> NDArray: """ @@ -928,6 +934,9 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: eval_basis = super()._evaluate(self._transform_samples(sample_pts))[:, ::-1] eval_basis = eval_basis[..., :-1] self.n_basis_funcs -= 1 + if self._clip_first: + idx = np.argmin(np.abs(eval_basis[:, 0] - 1)) + eval_basis[:idx, 0] = 1 return eval_basis From 47cfaeeca4ab7c5fcbf5394365d7fadcbc5d993d Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 21 Nov 2023 16:18:34 -0500 Subject: [PATCH 06/33] changed default --- src/neurostatslib/basis.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/neurostatslib/basis.py b/src/neurostatslib/basis.py index 8b6eae27..1af34f44 100644 --- a/src/neurostatslib/basis.py +++ b/src/neurostatslib/basis.py @@ -875,7 +875,7 @@ def __init__( n_basis_funcs: int, alpha: float = 1.0, extend_and_trim_last: bool = True, - clip_first: bool = True + clip_first: bool = False ) -> None: super().__init__(n_basis_funcs, alpha=alpha) self.extend_and_trim_last = extend_and_trim_last @@ -937,6 +937,7 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: if self._clip_first: idx = np.argmin(np.abs(eval_basis[:, 0] - 1)) eval_basis[:idx, 0] = 1 + return eval_basis From 56c51db348098d02f2aed99b20179185728cb79a Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 30 Nov 2023 14:37:54 -0500 Subject: [PATCH 07/33] reparametrization --- docs/examples/plot_1D_basis_function.py | 2 +- src/{neurostatslib => nemos}/basis.py | 25 ++++++++++++------------- tests/test_basis.py | 2 +- 3 files changed, 14 insertions(+), 15 deletions(-) rename src/{neurostatslib => nemos}/basis.py (98%) diff --git a/docs/examples/plot_1D_basis_function.py b/docs/examples/plot_1D_basis_function.py index b6b8bf94..1112e5c3 100644 --- a/docs/examples/plot_1D_basis_function.py +++ b/docs/examples/plot_1D_basis_function.py @@ -69,7 +69,7 @@ # evaluate a log-spaced cosine raised function basis. # Instantiate the basis noting that the `RaisedCosineBasisLog` does not require an `order` parameter -raised_cosine_log = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=10) +raised_cosine_log = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=21, alpha=1.5) # Evaluate the raised cosine basis at the equi-spaced sample points # (same method in all Basis elements) diff --git a/src/neurostatslib/basis.py b/src/nemos/basis.py similarity index 98% rename from src/neurostatslib/basis.py rename to src/nemos/basis.py index 1af34f44..5594a20f 100644 --- a/src/neurostatslib/basis.py +++ b/src/nemos/basis.py @@ -12,7 +12,7 @@ from numpy.typing import ArrayLike, NDArray from scipy.interpolate import splev -from neurostatslib.utils import row_wise_kron +from .utils import row_wise_kron __all__ = [ "MSplineBasis", @@ -64,7 +64,7 @@ def _evaluate(self, *xi: NDArray) -> NDArray: pass @staticmethod - def _get_samples(*n_samples: int) -> Generator[NDArray, None, None]: + def _get_samples(*n_samples: int) -> Generator[NDArray]: """Get equi-spaced samples for all the input dimensions. This will be used to evaluate the basis on a grid of @@ -520,12 +520,11 @@ class MSplineBasis(SplineBasis): at each interior knot. The higher this number, the smoother the basis representation will be. - References ---------- - .. [1] Ramsay, J. O. (1988). Monotone regression splines in action. - Statistical science, 3(4), 425-441. - + [^1]: + Ramsay, J. O. (1988). Monotone regression splines in action. + Statistical science, 3(4), 425-441. """ def __init__(self, n_basis_funcs: int, order: int = 2) -> None: @@ -581,8 +580,9 @@ class BSplineBasis(SplineBasis): References ---------- - ..[2] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques. - Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5 + [^2]: + Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques. + Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5 """ @@ -613,7 +613,6 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: The evaluation is performed by looping over each element and using `splev` from SciPy to compute the basis values. """ - # add knots knot_locs = self._generate_knots(sample_pts, 0.0, 1.0) @@ -677,7 +676,6 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: The evaluation is performed by looping over each element and using `splev` from SciPy to compute the basis values. """ - knot_locs = self._generate_knots(sample_pts, 0.0, 1.0, is_cyclic=True) # for cyclic, do not repeat knots @@ -930,10 +928,11 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: eval_basis = super()._evaluate(self._transform_samples(sample_pts))[:, ::-1] else: # temporarily add a basis element - self.n_basis_funcs += 1 + n_trim = int(np.round(self.alpha)) + self.n_basis_funcs += n_trim eval_basis = super()._evaluate(self._transform_samples(sample_pts))[:, ::-1] - eval_basis = eval_basis[..., :-1] - self.n_basis_funcs -= 1 + eval_basis = eval_basis[..., :-n_trim] + self.n_basis_funcs -= n_trim if self._clip_first: idx = np.argmin(np.abs(eval_basis[:, 0] - 1)) eval_basis[:idx, 0] = 1 diff --git a/tests/test_basis.py b/tests/test_basis.py index f0e2067d..7fb78e43 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -255,7 +255,7 @@ def test_input_to_evaluate_is_arraylike(self, arraylike): @pytest.mark.parametrize( "args, sample_size", - [[{"n_basis_funcs": n_basis}, 100] for n_basis in [1, 2, 10, 100]], + [[{"n_basis_funcs": n_basis}, 100] for n_basis in [2, 10, 100]], ) def test_evaluate_returns_expected_number_of_basis(self, args, sample_size): """ From f2b270eb432eac1ed89fff702fcc4a87b660b299 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 30 Nov 2023 14:40:16 -0500 Subject: [PATCH 08/33] linted --- src/nemos/basis.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 5594a20f..35475081 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -744,7 +744,6 @@ def __init__(self, n_basis_funcs: int, alpha: float = 1.0) -> None: self._check_alpha(alpha) self._alpha = alpha - @property def alpha(self): """Return width of the raised cosine.""" @@ -873,7 +872,7 @@ def __init__( n_basis_funcs: int, alpha: float = 1.0, extend_and_trim_last: bool = True, - clip_first: bool = False + clip_first: bool = False, ) -> None: super().__init__(n_basis_funcs, alpha=alpha) self.extend_and_trim_last = extend_and_trim_last From 8c9b1527900a53eca4c166cf5055a11c046ce7e8 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 30 Nov 2023 14:46:52 -0500 Subject: [PATCH 09/33] updated param names --- src/nemos/basis.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 35475081..a243f615 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -841,23 +841,17 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear): Parameters ---------- - n_basis_funcs : - The number of basis functions. - alpha : - Width of the raised cosine. By default, it's set to 1.0. - remove_last_basis: - If True, removes the last basis element so that the basis ends in zero. - - Attributes - ---------- n_basis_funcs : The number of basis functions. alpha : Width of the raised cosine. By default, it's set to 1.0. extend_and_trim_last: - If set to True, the algorithm first constructs a basis with `n_basis_funcs + 1` elements - and subsequently trims off the last basis element. This ensures that the final basis element - concludes at a value of 0 instead of 1. + If set to True, the algorithm first constructs a basis with `n_basis_funcs + 1` elements + and subsequently trims off the last basis element. This ensures that the final basis element + concludes at a value of 0 instead of 1. + force_first_basis_to_one: + If set to True, adjusts the first basis function so that it starts with a value of one. + This could be useful to capture the refractory period. References ---------- @@ -872,11 +866,11 @@ def __init__( n_basis_funcs: int, alpha: float = 1.0, extend_and_trim_last: bool = True, - clip_first: bool = False, + force_first_basis_to_one: bool = False, ) -> None: super().__init__(n_basis_funcs, alpha=alpha) self.extend_and_trim_last = extend_and_trim_last - self._clip_first = clip_first + self._force_first_basis_to_one = force_first_basis_to_one def _transform_samples(self, sample_pts: NDArray) -> NDArray: """ @@ -932,7 +926,7 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: eval_basis = super()._evaluate(self._transform_samples(sample_pts))[:, ::-1] eval_basis = eval_basis[..., :-n_trim] self.n_basis_funcs -= n_trim - if self._clip_first: + if self._force_first_basis_to_one: idx = np.argmin(np.abs(eval_basis[:, 0] - 1)) eval_basis[:idx, 0] = 1 From b23eebdb79c6ac213be0f91f4a6578c1938ac106 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 30 Nov 2023 14:52:11 -0500 Subject: [PATCH 10/33] updated default --- src/nemos/basis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index a243f615..b0ff03f1 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -864,7 +864,7 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear): def __init__( self, n_basis_funcs: int, - alpha: float = 1.0, + alpha: float = 1.5, extend_and_trim_last: bool = True, force_first_basis_to_one: bool = False, ) -> None: From 9936ca37856a9dbe466130efa0d37dc0196dd0da Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 30 Nov 2023 16:55:05 -0500 Subject: [PATCH 11/33] improved parametrization --- docs/examples/plot_1D_basis_function.py | 4 ++-- src/nemos/basis.py | 26 +++++++++++++------------ 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/docs/examples/plot_1D_basis_function.py b/docs/examples/plot_1D_basis_function.py index 1112e5c3..7cd8eb3a 100644 --- a/docs/examples/plot_1D_basis_function.py +++ b/docs/examples/plot_1D_basis_function.py @@ -69,11 +69,11 @@ # evaluate a log-spaced cosine raised function basis. # Instantiate the basis noting that the `RaisedCosineBasisLog` does not require an `order` parameter -raised_cosine_log = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=21, alpha=1.5) +raised_cosine_log = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=3, alpha=1.5, time_scaling=10) # Evaluate the raised cosine basis at the equi-spaced sample points # (same method in all Basis elements) -samples, eval_basis = raised_cosine_log.evaluate_on_grid(1000) +samples, eval_basis = raised_cosine_log.evaluate_on_grid(100) # Plot the evaluated log-spaced raised cosine basis plt.figure() diff --git a/src/nemos/basis.py b/src/nemos/basis.py index b0ff03f1..acc207f3 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -865,12 +865,18 @@ def __init__( self, n_basis_funcs: int, alpha: float = 1.5, + time_scaling: float = None, extend_and_trim_last: bool = True, force_first_basis_to_one: bool = False, ) -> None: super().__init__(n_basis_funcs, alpha=alpha) self.extend_and_trim_last = extend_and_trim_last self._force_first_basis_to_one = force_first_basis_to_one + if time_scaling is None: + self._time_scaling = np.pi * (self.n_basis_funcs - 1) * 10 + else: + self._time_scaling = time_scaling + print(f"time_scaling: {self._time_scaling}") def _transform_samples(self, sample_pts: NDArray) -> NDArray: """ @@ -888,15 +894,11 @@ def _transform_samples(self, sample_pts: NDArray) -> NDArray: Transformed version of the sample points that matches the Raised Cosine basis domain, shape (n_samples, ). """ - # if equi-spaced samples, this is equivalent to - # log_spaced_pts = np.logspace( - # np.log10((self.n_basis_funcs - 1) * np.pi), - # -1, - # sample_pts.shape[0] - # ) - 0.1 - # log_spaced_pts = log_spaced_pts / (np.pi * (self.n_basis_funcs - 1)) - base = np.pi * (self.n_basis_funcs - 1) * 10 - log_spaced_pts = base ** (-sample_pts) - 1 / base + # This is equivalent to log-spacing the points with base self._time_scaling + # and then adjust the extremes to be 0 and 1. + # as the base tends to 1, the points will be linearly spaced. + # as the base tends to inf, the points will be concentrated in 0. + log_spaced_pts = (self._time_scaling ** sample_pts - 1) / (self._time_scaling - 1) return log_spaced_pts def _evaluate(self, sample_pts: NDArray) -> NDArray: @@ -918,12 +920,12 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: If the sample provided do not lie in [0,1]. """ if not self.extend_and_trim_last: - eval_basis = super()._evaluate(self._transform_samples(sample_pts))[:, ::-1] + eval_basis = super()._evaluate(self._transform_samples(sample_pts))[::-1, ::-1]#[..., ::-1] else: # temporarily add a basis element - n_trim = int(np.round(self.alpha)) + n_trim = int(np.ceil(self.alpha)) self.n_basis_funcs += n_trim - eval_basis = super()._evaluate(self._transform_samples(sample_pts))[:, ::-1] + eval_basis = super()._evaluate(self._transform_samples(sample_pts))[::-1, ::-1] eval_basis = eval_basis[..., :-n_trim] self.n_basis_funcs -= n_trim if self._force_first_basis_to_one: From 16bd6849d2704cdc4bb9ef0a8125bc1237bd8ce9 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 7 Dec 2023 17:12:07 -0500 Subject: [PATCH 12/33] removed unused variable --- docs/examples/plot_1D_basis_function.py | 2 +- src/nemos/basis.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/examples/plot_1D_basis_function.py b/docs/examples/plot_1D_basis_function.py index 7cd8eb3a..327e8b7b 100644 --- a/docs/examples/plot_1D_basis_function.py +++ b/docs/examples/plot_1D_basis_function.py @@ -69,7 +69,7 @@ # evaluate a log-spaced cosine raised function basis. # Instantiate the basis noting that the `RaisedCosineBasisLog` does not require an `order` parameter -raised_cosine_log = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=3, alpha=1.5, time_scaling=10) +raised_cosine_log = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=10, alpha=1., time_scaling=50) # Evaluate the raised cosine basis at the equi-spaced sample points # (same method in all Basis elements) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index acc207f3..d8c17892 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -871,7 +871,6 @@ def __init__( ) -> None: super().__init__(n_basis_funcs, alpha=alpha) self.extend_and_trim_last = extend_and_trim_last - self._force_first_basis_to_one = force_first_basis_to_one if time_scaling is None: self._time_scaling = np.pi * (self.n_basis_funcs - 1) * 10 else: @@ -920,17 +919,18 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray: If the sample provided do not lie in [0,1]. """ if not self.extend_and_trim_last: - eval_basis = super()._evaluate(self._transform_samples(sample_pts))[::-1, ::-1]#[..., ::-1] + # flip the order of raised-cosine in the: + # - time axis: compression at the beginning of the interval + # - basis axis: set the first basis to be the one near t = 0 + eval_basis = super()._evaluate(self._transform_samples(sample_pts))[::-1, ::-1] else: - # temporarily add a basis element + # temporarily add n_trim basis element + # n_trim guarantees that the last basis element decays to 0. n_trim = int(np.ceil(self.alpha)) self.n_basis_funcs += n_trim eval_basis = super()._evaluate(self._transform_samples(sample_pts))[::-1, ::-1] eval_basis = eval_basis[..., :-n_trim] self.n_basis_funcs -= n_trim - if self._force_first_basis_to_one: - idx = np.argmin(np.abs(eval_basis[:, 0] - 1)) - eval_basis[:idx, 0] = 1 return eval_basis From b9bc61b010c8b9a7ed6b2e1004c23093d140a133 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 15 Dec 2023 11:04:49 -0500 Subject: [PATCH 13/33] added checks --- src/nemos/basis.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 8f8a4ad1..78e325dd 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -1000,6 +1000,7 @@ def evaluate(self, sample_pts: NDArray) -> NDArray: ValueError If the sample provided do not lie in [0,1]. """ + (sample_pts,) = self._check_evaluate_input(sample_pts) if not self.extend_and_trim_last: # keep the last basis, i.e. do not enforce decay to zero # for the filter. From e8f318cba31ed334a3e9b97813b2f9a7c49bbdfa Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 15 Dec 2023 11:06:45 -0500 Subject: [PATCH 14/33] linted --- src/nemos/__init__.py | 10 +--------- src/nemos/basis.py | 10 +++++++--- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/nemos/__init__.py b/src/nemos/__init__.py index 28836560..8bcee80f 100644 --- a/src/nemos/__init__.py +++ b/src/nemos/__init__.py @@ -1,11 +1,3 @@ #!/usr/bin/env python3 -from . import ( - basis, - exceptions, - glm, - observation_models, - regularizer, - simulation, - utils, -) +from . import basis, exceptions, glm, observation_models, regularizer, simulation, utils diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 78e325dd..1bbe0ab5 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -946,7 +946,7 @@ def __init__( super().__init__(n_basis_funcs, width=width) self.extend_and_trim_last = extend_and_trim_last if time_scaling is None: - self.time_scaling = 50. + self.time_scaling = 50.0 else: self.time_scaling = time_scaling @@ -957,7 +957,9 @@ def time_scaling(self): @time_scaling.setter def time_scaling(self, time_scaling): if time_scaling <= 0: - raise ValueError(f"Only strictly positive time_scaling are allowed, {time_scaling} provided instead.") + raise ValueError( + f"Only strictly positive time_scaling are allowed, {time_scaling} provided instead." + ) self._time_scaling = time_scaling def _transform_samples(self, sample_pts: NDArray) -> NDArray: @@ -979,7 +981,9 @@ def _transform_samples(self, sample_pts: NDArray) -> NDArray: # This log-stretching of the sample axis has the following effect: # - as the time_scaling tends to 0, the points will be linearly spaced. # - as the time_scaling tends to inf, basis will be dense around 0. - log_spaced_pts = np.log(self.time_scaling * sample_pts + 1) / np.log(self.time_scaling + 1) + log_spaced_pts = np.log(self.time_scaling * sample_pts + 1) / np.log( + self.time_scaling + 1 + ) return log_spaced_pts def evaluate(self, sample_pts: NDArray) -> NDArray: From b4b2e2c463fb4dd01dd447ef2b5af5f5b21f0fac Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 15 Dec 2023 12:08:20 -0500 Subject: [PATCH 15/33] fixed tests --- src/nemos/basis.py | 45 ++++---- tests/test_basis.py | 252 +++++++++++++++++++++++++++++++++++++------- 2 files changed, 237 insertions(+), 60 deletions(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 1bbe0ab5..cb4068d9 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -51,7 +51,7 @@ def __init__(self, n_basis_funcs: int) -> None: self._check_n_basis_min() @abc.abstractmethod - def evaluate(self, *xi: NDArray) -> NDArray: + def evaluate(self, *xi: ArrayLike) -> NDArray: """ Evaluate the basis set at the given samples x1,...,xn using the subclass-specific "evaluate" method. @@ -327,7 +327,7 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None: def _check_n_basis_min(self) -> None: pass - def evaluate(self, *xi: NDArray) -> NDArray: + def evaluate(self, *xi: ArrayLike) -> NDArray: """ Evaluate the basis at the input samples. @@ -383,7 +383,7 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None: def _check_n_basis_min(self) -> None: pass - def evaluate(self, *xi: NDArray) -> NDArray: + def evaluate(self, *xi: ArrayLike) -> NDArray: """ Evaluate the basis at the input samples. @@ -527,7 +527,7 @@ class MSplineBasis(SplineBasis): def __init__(self, n_basis_funcs: int, order: int = 2) -> None: super().__init__(n_basis_funcs, order) - def evaluate(self, sample_pts: NDArray) -> NDArray: + def evaluate(self, sample_pts: ArrayLike) -> NDArray: """Generate basis functions with given spacing. Parameters @@ -605,7 +605,7 @@ class BSplineBasis(SplineBasis): def __init__(self, n_basis_funcs: int, order: int = 2): super().__init__(n_basis_funcs, order=order) - def evaluate(self, sample_pts: NDArray) -> NDArray: + def evaluate(self, sample_pts: ArrayLike) -> NDArray: """ Evaluate the B-spline basis functions with given sample points. @@ -693,7 +693,7 @@ def __init__(self, n_basis_funcs: int, order: int = 2): f"order {self.order} specified instead!" ) - def evaluate(self, sample_pts: NDArray) -> NDArray: + def evaluate(self, sample_pts: ArrayLike) -> NDArray: """Evaluate the Cyclic B-spline basis functions with given sample points. Parameters @@ -793,7 +793,7 @@ class RaisedCosineBasisLinear(Basis): 11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005 """ - def __init__(self, n_basis_funcs: int, width: float = 1.0) -> None: + def __init__(self, n_basis_funcs: int, width: float = 2.0) -> None: super().__init__(n_basis_funcs) self._n_input_dimensionality = 1 self.width = width @@ -821,18 +821,18 @@ def _check_width(width: float): Raises ------ ValueError - If alpha < 1 or 2*alpha is not a positive integer. Values that do not match + If width <= 1 or 2*width is not a positive integer. Values that do not match this constraint will result in: - - No overlap between bumps (alpha < 1). - - Oscillatory behavior when summing the basis elements (2*alpha not integer). + - No overlap between bumps (width < 1). + - Oscillatory behavior when summing the basis elements (2*width not integer). """ - if width < 1 or (not np.isclose(width * 2, round(2 * width))): + if width <= 1 or (not np.isclose(width * 2, round(2 * width))): raise ValueError( f"Invalid raised cosine width. " - f"2*alpha must be a positive integer, 2*alpha = {2 * width} instead!" + f"2*width must be a positive integer, 2*width = {2 * width} instead!" ) - def evaluate(self, sample_pts: NDArray) -> NDArray: + def evaluate(self, sample_pts: ArrayLike) -> NDArray: """Generate basis functions with given samples. Parameters @@ -920,13 +920,10 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear): The number of basis functions. width : Width of the raised cosine. By default, it's set to 1.0. - extend_and_trim_last: - If set to True, the algorithm first constructs a basis with `n_basis_funcs + 1` elements - and subsequently trims off the last basis element. This ensures that the final basis element + enforce_decay_to_zero: + If set to True, the algorithm first constructs a basis with `n_basis_funcs + ceil(alpha)` elements + and subsequently trims off the extra basis elements. This ensures that the final basis element concludes at a value of 0 instead of 1. - force_first_basis_to_one: - If set to True, adjusts the first basis function so that it starts with a value of one. - This could be useful to capture the refractory period. References ---------- @@ -939,12 +936,12 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear): def __init__( self, n_basis_funcs: int, - width: float = 1.5, + width: float = 2.0, time_scaling: float = None, - extend_and_trim_last: bool = True, + enforce_decay_to_zero: bool = True, ) -> None: super().__init__(n_basis_funcs, width=width) - self.extend_and_trim_last = extend_and_trim_last + self.enforce_decay_to_zero = enforce_decay_to_zero if time_scaling is None: self.time_scaling = 50.0 else: @@ -986,7 +983,7 @@ def _transform_samples(self, sample_pts: NDArray) -> NDArray: ) return log_spaced_pts - def evaluate(self, sample_pts: NDArray) -> NDArray: + def evaluate(self, sample_pts: ArrayLike) -> NDArray: """Generate log-spaced raised cosine basis with given samples. Parameters @@ -1005,7 +1002,7 @@ def evaluate(self, sample_pts: NDArray) -> NDArray: If the sample provided do not lie in [0,1]. """ (sample_pts,) = self._check_evaluate_input(sample_pts) - if not self.extend_and_trim_last: + if not self.enforce_decay_to_zero: # keep the last basis, i.e. do not enforce decay to zero # for the filter. eval_basis = super().evaluate(self._transform_samples(sample_pts)) diff --git a/tests/test_basis.py b/tests/test_basis.py index 57f1f5ab..fd757d54 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -156,9 +156,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): basis_obj = self.cls(n_basis_funcs=5) inputs = [np.linspace(0, 1, 20)] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, match="evaluate\(\) missing 1 required positional argument" + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: @@ -204,14 +209,94 @@ def test_evaluate_on_grid_input_number(self, n_input): basis_obj = self.cls(n_basis_funcs=5) inputs = [10] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) missing 1 required positional argument", + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: basis_obj.evaluate_on_grid(*inputs) + @pytest.mark.parametrize( + "width ,expectation", + [ + (-1, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + (0, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + (0.5, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + (1, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + (1.5, does_not_raise()), + (2, does_not_raise()), + (2.1, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + ], + ) + def test_width_values(self, width, expectation): + """Test allowable widths: integer multiple of 1/2, greater than 1.""" + with expectation: + self.cls(n_basis_funcs=5, width=width) + + @pytest.mark.parametrize("width", [1.5, 2, 2.5]) + def test_decay_to_zero_basis_number_match(self, width): + """Test that the number of basis is preserved.""" + n_basis_funcs = 10 + _, ev = self.cls( + n_basis_funcs=n_basis_funcs, width=width, enforce_decay_to_zero=True + ).evaluate_on_grid(2) + assert ev.shape[1] == n_basis_funcs, ( + "Basis function number mismatch. " + f"Expected {n_basis_funcs}, got {ev.shape[1]} instead!" + ) + + @pytest.mark.parametrize( + "time_scaling ,expectation", + [ + ( + -1, + pytest.raises( + ValueError, match="Only strictly positive time_scaling are allowed" + ), + ), + ( + 0, + pytest.raises( + ValueError, match="Only strictly positive time_scaling are allowed" + ), + ), + (0.1, does_not_raise()), + (10, does_not_raise()), + ], + ) + def test_time_scaling_values(self, time_scaling, expectation): + """Test that only positive time_scaling are allowed.""" + with expectation: + self.cls(n_basis_funcs=5, time_scaling=time_scaling) + + def test_time_scaling_property(self): + """Test that larger time_scaling results in larger departures from linearity.""" + time_scaling = [0.1, 10, 100] + n_basis_funcs = 5 + _, lin_ev = basis.RaisedCosineBasisLinear(n_basis_funcs).evaluate_on_grid(100) + corr = np.zeros(len(time_scaling)) + for idx, ts in enumerate(time_scaling): + # set default decay to zero to get comparable basis + basis_log = self.cls( + n_basis_funcs=n_basis_funcs, + time_scaling=ts, + enforce_decay_to_zero=False, + ) + _, log_ev = basis_log.evaluate_on_grid(100) + # compute the correlation + corr[idx] = (lin_ev.flatten() @ log_ev.flatten()) / ( + np.linalg.norm(lin_ev.flatten()) * np.linalg.norm(log_ev.flatten()) + ) + # check that the correlation decreases as time_scale increases + assert np.all(np.diff(corr) < 0), "As time scales increases, deviation from linearity should increase!" + class TestRaisedCosineLinearBasis(BasisFuncsTesting): cls = basis.RaisedCosineBasisLinear @@ -278,8 +363,11 @@ def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs): """ raise_exception = n_basis_funcs < 2 if raise_exception: - with pytest.raises(ValueError, match=f"Object class {self.cls.__name__} " - r"requires >= 2 basis elements\."): + with pytest.raises( + ValueError, + match=f"Object class {self.cls.__name__} " + r"requires >= 2 basis elements\.", + ): self.cls(n_basis_funcs=n_basis_funcs) else: self.cls(n_basis_funcs=n_basis_funcs) @@ -309,9 +397,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): basis_obj = self.cls(n_basis_funcs=5) inputs = [np.linspace(0, 1, 20)] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, match="evaluate\(\) missing 1 required positional argument" + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: @@ -357,14 +450,37 @@ def test_evaluate_on_grid_input_number(self, n_input): basis_obj = self.cls(n_basis_funcs=5) inputs = [10] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) missing 1 required positional argument", + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: basis_obj.evaluate_on_grid(*inputs) + @pytest.mark.parametrize( + "width ,expectation", + [ + (-1, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + (0, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + (0.5, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + (1, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + (1.5, does_not_raise()), + (2, does_not_raise()), + (2.1, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + ], + ) + def test_width_values(self, width, expectation): + """Test allowable widths: integer multiple of 1/2, greater than 1.""" + with expectation: + self.cls(n_basis_funcs=5, width=width) + class TestMSplineBasis(BasisFuncsTesting): cls = basis.MSplineBasis @@ -463,9 +579,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): basis_obj = self.cls(n_basis_funcs=5, order=3) inputs = [np.linspace(0, 1, 20)] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, match="evaluate\(\) missing 1 required positional argument" + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: @@ -511,9 +632,15 @@ def test_evaluate_on_grid_input_number(self, n_input): basis_obj = self.cls(n_basis_funcs=5, order=3) inputs = [10] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) missing 1 required positional argument", + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: @@ -535,7 +662,8 @@ def test_non_empty_samples(self, samples): self.cls(5, decay_rates=np.arange(1, 6)).evaluate(samples) @pytest.mark.parametrize( - "eval_input", [0, [0] * 6, (0,) * 6, np.array([0] * 6), jax.numpy.array([0] * 6)] + "eval_input", + [0, [0] * 6, (0,) * 6, np.array([0] * 6), jax.numpy.array([0] * 6)], ) def test_evaluate_input(self, eval_input): """ @@ -544,7 +672,10 @@ def test_evaluate_input(self, eval_input): basis_obj = self.cls(n_basis_funcs=5, decay_rates=np.arange(1, 6)) if isinstance(eval_input, int): # OrthExponentialBasis is special -- cannot accept int input - with pytest.raises(ValueError, match="OrthExponentialBasis requires at least as many samples"): + with pytest.raises( + ValueError, + match="OrthExponentialBasis requires at least as many samples", + ): basis_obj.evaluate(eval_input) else: basis_obj.evaluate(eval_input) @@ -623,9 +754,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): basis_obj = self.cls(n_basis_funcs=5, decay_rates=np.arange(1, 6)) inputs = [np.linspace(0, 1, 20)] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, match="evaluate\(\) missing 1 required positional argument" + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: @@ -670,9 +806,15 @@ def test_evaluate_on_grid_input_number(self, n_input): basis_obj = self.cls(n_basis_funcs=5, decay_rates=np.arange(1, 6)) inputs = [10] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) missing 1 required positional argument", + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: @@ -826,9 +968,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): basis_obj = self.cls(n_basis_funcs=5, order=3) inputs = [np.linspace(0, 1, 20)] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, match="evaluate\(\) missing 1 required positional argument" + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: @@ -878,9 +1025,15 @@ def test_evaluate_on_grid_input_number(self, n_input): basis_obj = self.cls(n_basis_funcs=5, order=3) inputs = [10] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) missing 1 required positional argument", + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: @@ -1017,9 +1170,14 @@ def test_number_of_required_inputs_evaluate(self, n_input): basis_obj = self.cls(n_basis_funcs=5, order=3) inputs = [np.linspace(0, 1, 20)] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, match="evaluate\(\) missing 1 required positional argument" + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: @@ -1069,9 +1227,15 @@ def test_evaluate_on_grid_input_number(self, n_input): basis_obj = self.cls(n_basis_funcs=5, order=3) inputs = [10] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) missing 1 required positional argument") + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) missing 1 required positional argument", + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",) + expectation = pytest.raises( + TypeError, + match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", + ) else: expectation = does_not_raise() with expectation: @@ -1235,10 +1399,14 @@ def test_number_of_required_inputs_evaluate( basis_a_obj = self.instantiate_basis(n_basis_a, basis_a) basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) basis_obj = basis_a_obj + basis_b_obj - required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality + required_dim = ( + basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality + ) inputs = [np.linspace(0, 1, 20)] * n_input if n_input != required_dim: - expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.") + expectation = pytest.raises( + TypeError, match="Input dimensionality mismatch." + ) else: expectation = does_not_raise() with expectation: @@ -1317,9 +1485,13 @@ def test_evaluate_on_grid_input_number( basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) basis_obj = basis_a_obj + basis_b_obj inputs = [20] * n_input - required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality + required_dim = ( + basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality + ) if n_input != required_dim: - expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.") + expectation = pytest.raises( + TypeError, match="Input dimensionality mismatch." + ) else: expectation = does_not_raise() with expectation: @@ -1443,10 +1615,14 @@ def test_number_of_required_inputs_evaluate( basis_a_obj = self.instantiate_basis(n_basis_a, basis_a) basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) basis_obj = basis_a_obj * basis_b_obj - required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality + required_dim = ( + basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality + ) inputs = [np.linspace(0, 1, 20)] * n_input if n_input != required_dim: - expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.") + expectation = pytest.raises( + TypeError, match="Input dimensionality mismatch." + ) else: expectation = does_not_raise() with expectation: @@ -1525,9 +1701,13 @@ def test_evaluate_on_grid_input_number( basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) basis_obj = basis_a_obj * basis_b_obj inputs = [20] * n_input - required_dim = basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality + required_dim = ( + basis_a_obj._n_input_dimensionality + basis_b_obj._n_input_dimensionality + ) if n_input != required_dim: - expectation = pytest.raises(TypeError, match="Input dimensionality mismatch.") + expectation = pytest.raises( + TypeError, match="Input dimensionality mismatch." + ) else: expectation = does_not_raise() with expectation: From ff32160c891703d2d2bd3feda7794a874a8a16d7 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 15 Dec 2023 12:15:04 -0500 Subject: [PATCH 16/33] fixed refs --- docs/examples/plot_1D_basis_function.py | 2 +- src/nemos/basis.py | 30 +++++++++++-------------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/docs/examples/plot_1D_basis_function.py b/docs/examples/plot_1D_basis_function.py index 827920d3..37ce1e1a 100644 --- a/docs/examples/plot_1D_basis_function.py +++ b/docs/examples/plot_1D_basis_function.py @@ -69,7 +69,7 @@ # evaluate a log-spaced cosine raised function basis. # Instantiate the basis noting that the `RaisedCosineBasisLog` does not require an `order` parameter -raised_cosine_log = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=10, width=1., time_scaling=50) +raised_cosine_log = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=10, width=1.5, time_scaling=50) # Evaluate the raised cosine basis at the equi-spaced sample points # (same method in all Basis elements) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index cb4068d9..86517ecc 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -505,7 +505,7 @@ def _check_n_basis_min(self) -> None: class MSplineBasis(SplineBasis): - """M-spline 1-dimensional basis functions. + """M-spline[$^1$](references) 1-dimensional basis functions. Parameters ---------- @@ -519,8 +519,7 @@ class MSplineBasis(SplineBasis): References ---------- - [^1]: - Ramsay, J. O. (1988). Monotone regression splines in action. + 1. Ramsay, J. O. (1988). Monotone regression splines in action. Statistical science, 3(4), 425-441. """ @@ -577,7 +576,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: class BSplineBasis(SplineBasis): """ - B-spline 1-dimensional basis functions. + B-spline[$^1$](references) 1-dimensional basis functions. Parameters ---------- @@ -596,8 +595,7 @@ class BSplineBasis(SplineBasis): References ---------- - [^2]: - Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques. + 1. Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques. Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5 """ @@ -774,20 +772,19 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: class RaisedCosineBasisLinear(Basis): """Represent linearly-spaced raised cosine basis functions. - This implementation is based on the cosine bumps used by Pillow et al. [2] - to uniformly tile the domain (if alpha = 1) or the internal points of the domain - (if alpha > 1). + This implementation is based on the cosine bumps used by Pillow et al.[$^1$](references) + to uniformly tile the internal points of the domain. Parameters ---------- n_basis_funcs : The number of basis functions. width : - Width of the raised cosine. By default, it's set to 1.0. + Width of the raised cosine. By default, it's set to 2.0. References ---------- - [2] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., + 1. Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., C. E. (2005). Prediction and decoding of retinal ganglion cell responses with a probabilistic spiking model. Journal of Neuroscience, 25(47), 11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005 @@ -910,24 +907,23 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear): """Represent log-spaced raised cosine basis functions. Similar to `RaisedCosineBasisLinear` but the basis functions are log-spaced. - This implementation is based on the cosine bumps used by Pillow et al. [2] - to uniformly tile the domain (if alpha = 1) or the internal points of the domain - (if alpha > 1). + This implementation is based on the cosine bumps used by Pillow et al.[$^1$](#references) + to uniformly tile the internal points of the domain. Parameters ---------- n_basis_funcs : The number of basis functions. width : - Width of the raised cosine. By default, it's set to 1.0. + Width of the raised cosine. By default, it's set to 2.0. enforce_decay_to_zero: - If set to True, the algorithm first constructs a basis with `n_basis_funcs + ceil(alpha)` elements + If set to True, the algorithm first constructs a basis with `n_basis_funcs + ceil(width)` elements and subsequently trims off the extra basis elements. This ensures that the final basis element concludes at a value of 0 instead of 1. References ---------- - [2] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., + 1. Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., C. E. (2005). Prediction and decoding of retinal ganglion cell responses with a probabilistic spiking model. Journal of Neuroscience, 25(47), 11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005 From 06757d9e508e1b92eeea116d3ab9c19573911a38 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 15 Dec 2023 12:16:06 -0500 Subject: [PATCH 17/33] fixed note --- docs/developers_notes/01-basis_module.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/developers_notes/01-basis_module.md b/docs/developers_notes/01-basis_module.md index ebb2ace5..c0bdf6f8 100644 --- a/docs/developers_notes/01-basis_module.md +++ b/docs/developers_notes/01-basis_module.md @@ -19,9 +19,7 @@ Abstract Class Basis │ │ │ └─ Concrete Subclass CyclicBSplineBasis │ -├─ Abstract Subclass RaisedCosineBasis -│ │ -│ ├─ Concrete Subclass RaisedCosineBasisLinear +├─ Concrete Subclass RaisedCosineBasisLinear │ │ │ └─ Concrete Subclass RaisedCosineBasisLog │ From 664169d9d966e0ebe7ea42752b749b89c654fa45 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 15 Dec 2023 14:19:22 -0500 Subject: [PATCH 18/33] improved docstring --- src/nemos/basis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 86517ecc..2032c3b1 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -919,7 +919,7 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear): enforce_decay_to_zero: If set to True, the algorithm first constructs a basis with `n_basis_funcs + ceil(width)` elements and subsequently trims off the extra basis elements. This ensures that the final basis element - concludes at a value of 0 instead of 1. + decays to 0. References ---------- From 4d58d3297d261a58ce7a4ee69915c70a875949d7 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 19 Dec 2023 12:46:18 -0500 Subject: [PATCH 19/33] reversed ci change to only ubuntu --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 83683b8a..d66a08b5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,7 +16,7 @@ jobs: tox: strategy: matrix: - os: [ubuntu-latest, macos-latest] #[ubuntu-latest, macos-latest, windows-latest] + os: [ubuntu-latest] #[ubuntu-latest, macos-latest, windows-latest] python-version: ['3.10'] #['3.8', '3.9', '3.10'] runs-on: ${{ matrix.os }} steps: From 1fcd44633c79aa72aacc068656d8517886a475e6 Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Tue, 19 Dec 2023 12:46:26 -0500 Subject: [PATCH 20/33] Update src/nemos/basis.py Co-authored-by: William F. Broderick --- src/nemos/basis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 2032c3b1..962ec222 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -505,7 +505,7 @@ def _check_n_basis_min(self) -> None: class MSplineBasis(SplineBasis): - """M-spline[$^1$](references) 1-dimensional basis functions. + """M-spline[$^1$](#references) 1-dimensional basis functions. Parameters ---------- From fd3d1d4b59bd1be57a60d4cde830f1d0bad5e952 Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Tue, 19 Dec 2023 12:47:14 -0500 Subject: [PATCH 21/33] Update src/nemos/basis.py Co-authored-by: William F. Broderick --- src/nemos/basis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 962ec222..25c7545f 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -576,7 +576,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: class BSplineBasis(SplineBasis): """ - B-spline[$^1$](references) 1-dimensional basis functions. + B-spline[$^1$](#references) 1-dimensional basis functions. Parameters ---------- From 93a1b1e69720962055790e0e98bf933402c8b227 Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Tue, 19 Dec 2023 12:47:22 -0500 Subject: [PATCH 22/33] Update src/nemos/basis.py Co-authored-by: William F. Broderick --- src/nemos/basis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 25c7545f..dcf9ee85 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -772,7 +772,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: class RaisedCosineBasisLinear(Basis): """Represent linearly-spaced raised cosine basis functions. - This implementation is based on the cosine bumps used by Pillow et al.[$^1$](references) + This implementation is based on the cosine bumps used by Pillow et al.[$^1$](#references) to uniformly tile the internal points of the domain. Parameters From de1a89504da12860f57ab1a3407b6ecc93f2395d Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 19 Dec 2023 12:57:52 -0500 Subject: [PATCH 23/33] made width a read only property, added comment on evaluate --- src/nemos/basis.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 2032c3b1..fcfb835f 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -793,19 +793,14 @@ class RaisedCosineBasisLinear(Basis): def __init__(self, n_basis_funcs: int, width: float = 2.0) -> None: super().__init__(n_basis_funcs) self._n_input_dimensionality = 1 - self.width = width + self._check_width(width) + self._width = width @property def width(self): """Return width of the raised cosine.""" return self._width - @width.setter - def width(self, width: float): - """Check and set width of the raised cosine.""" - self._check_width(width) - self._width = width - @staticmethod def _check_width(width: float): """Validate the width value. @@ -854,6 +849,9 @@ def evaluate(self, sample_pts: ArrayLike) -> NDArray: peaks = np.linspace(0, 1, self.n_basis_funcs) delta = peaks[1] - peaks[0] + # generate a set of shifted cosines, and constrain them to be non-zero + # over a single period, then enforce the codomain to be [0,1], by adding 1 + # and then multiply by 0.5 basis_funcs = 0.5 * ( np.cos( np.clip( From cf4a60de81a49f81c31c6bec89063e36d91a7fff Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 19 Dec 2023 13:02:47 -0500 Subject: [PATCH 24/33] getter only for time scaling --- src/nemos/basis.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index fcfb835f..448e520f 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -937,21 +937,21 @@ def __init__( super().__init__(n_basis_funcs, width=width) self.enforce_decay_to_zero = enforce_decay_to_zero if time_scaling is None: - self.time_scaling = 50.0 + self._time_scaling = 50.0 else: - self.time_scaling = time_scaling + self._check_time_scaling(time_scaling) + self._time_scaling = time_scaling @property def time_scaling(self): return self._time_scaling - @time_scaling.setter - def time_scaling(self, time_scaling): + @staticmethod + def _check_time_scaling(time_scaling): if time_scaling <= 0: raise ValueError( f"Only strictly positive time_scaling are allowed, {time_scaling} provided instead." ) - self._time_scaling = time_scaling def _transform_samples(self, sample_pts: NDArray) -> NDArray: """ @@ -959,13 +959,12 @@ def _transform_samples(self, sample_pts: NDArray) -> NDArray: Parameters ---------- - sample_pts : NDArray + sample_pts : Sample points used for evaluating the splines, shape (n_samples, ). Returns ------- - NDArray Transformed version of the sample points that matches the Raised Cosine basis domain, shape (n_samples, ). """ From 120afe187f136386cdc16347b594dc123cb20da3 Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Tue, 19 Dec 2023 13:02:52 -0500 Subject: [PATCH 25/33] Update src/nemos/basis.py Co-authored-by: William F. Broderick --- src/nemos/basis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index dcf9ee85..567f5430 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -972,8 +972,8 @@ def _transform_samples(self, sample_pts: NDArray) -> NDArray: shape (n_samples, ). """ # This log-stretching of the sample axis has the following effect: - # - as the time_scaling tends to 0, the points will be linearly spaced. - # - as the time_scaling tends to inf, basis will be dense around 0. + # - as the time_scaling tends to 0, the points will be linearly spaced across the whole domain. + # - as the time_scaling tends to inf, basis will be small and dense around 0 and progressively larger and less dense towards 1. log_spaced_pts = np.log(self.time_scaling * sample_pts + 1) / np.log( self.time_scaling + 1 ) From b86760ef705b04160858bf074df337417d11316b Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Tue, 19 Dec 2023 13:08:37 -0500 Subject: [PATCH 26/33] Update tests/test_basis.py Co-authored-by: William F. Broderick --- tests/test_basis.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index fd757d54..96760840 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -255,18 +255,8 @@ def test_decay_to_zero_basis_number_match(self, width): @pytest.mark.parametrize( "time_scaling ,expectation", [ - ( - -1, - pytest.raises( - ValueError, match="Only strictly positive time_scaling are allowed" - ), - ), - ( - 0, - pytest.raises( - ValueError, match="Only strictly positive time_scaling are allowed" - ), - ), + (-1, pytest.raises(ValueError, match="Only strictly positive time_scaling are allowed")), + (0, pytest.raises(ValueError, match="Only strictly positive time_scaling are allowed")), (0.1, does_not_raise()), (10, does_not_raise()), ], From 50fec49cf4f932648738c4ff48867427e8eb0fe3 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 19 Dec 2023 13:13:13 -0500 Subject: [PATCH 27/33] linted --- src/nemos/basis.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 43cbe818..ab1d64d1 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -970,7 +970,8 @@ def _transform_samples(self, sample_pts: NDArray) -> NDArray: """ # This log-stretching of the sample axis has the following effect: # - as the time_scaling tends to 0, the points will be linearly spaced across the whole domain. - # - as the time_scaling tends to inf, basis will be small and dense around 0 and progressively larger and less dense towards 1. + # - as the time_scaling tends to inf, basis will be small and dense around 0 and + # progressively larger and less dense towards 1. log_spaced_pts = np.log(self.time_scaling * sample_pts + 1) / np.log( self.time_scaling + 1 ) From 5a1a350c10abe519a7eac785db70d3f88f1a1c3c Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 19 Dec 2023 14:49:16 -0500 Subject: [PATCH 28/33] modified enforcing basis to zero --- src/nemos/basis.py | 50 ++++++++++++++++++++++++++++------------------ 1 file changed, 31 insertions(+), 19 deletions(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index ab1d64d1..47bf5c56 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -847,7 +847,7 @@ def evaluate(self, sample_pts: ArrayLike) -> NDArray: if any(sample_pts < 0) or any(sample_pts > 1): raise ValueError("Sample points for RaisedCosine basis must lie in [0,1]!") - peaks = np.linspace(0, 1, self.n_basis_funcs) + peaks = self._compute_peaks() delta = peaks[1] - peaks[0] # generate a set of shifted cosines, and constrain them to be non-zero # over a single period, then enforce the codomain to be [0,1], by adding 1 @@ -865,6 +865,16 @@ def evaluate(self, sample_pts: ArrayLike) -> NDArray: return basis_funcs + def _compute_peaks(self): + """ + Compute the location of raised cosine peaks. + + Returns + ------- + Peak locations of each basis element. + """ + return np.linspace(0, 1, self.n_basis_funcs) + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """Evaluate the basis set on a grid of equi-spaced sample points. @@ -977,6 +987,25 @@ def _transform_samples(self, sample_pts: NDArray) -> NDArray: ) return log_spaced_pts + def _compute_peaks(self): + """ + Peak location of each log-spaced cosine basis element + + Compute the peak location for the log-spaced raised cosine basis. + Enforcing that the last basis decays to zero is equivalent to + setting the last peak to a value smaller than 1. + + Returns + ------- + Peak locations of each basis element. + + """ + if self.enforce_decay_to_zero: + last_peak = 1 - np.ceil(self.width) / (self.n_basis_funcs + np.ceil(self.width) - 1) + return np.linspace(0, last_peak, self.n_basis_funcs) + else: + super()._compute_peaks() + def evaluate(self, sample_pts: ArrayLike) -> NDArray: """Generate log-spaced raised cosine basis with given samples. @@ -996,24 +1025,7 @@ def evaluate(self, sample_pts: ArrayLike) -> NDArray: If the sample provided do not lie in [0,1]. """ (sample_pts,) = self._check_evaluate_input(sample_pts) - if not self.enforce_decay_to_zero: - # keep the last basis, i.e. do not enforce decay to zero - # for the filter. - eval_basis = super().evaluate(self._transform_samples(sample_pts)) - else: - # temporarily add n_trim basis element - # guaranteeing that the last basis element decays to 0 for any width. - n_trim = int(np.ceil(self.width)) - self.n_basis_funcs += n_trim - # wrap the evaluation in a try -> finally to make sure that the original - # basis function number is preserved even if an exception is raised. - try: - eval_basis = super().evaluate(self._transform_samples(sample_pts)) - eval_basis = eval_basis[..., :-n_trim] - finally: - self.n_basis_funcs -= n_trim - - return eval_basis + return super().evaluate(self._transform_samples(sample_pts)) class OrthExponentialBasis(Basis): From 753ed207368d1b023695a4d9af08b69174178b67 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 19 Dec 2023 14:51:24 -0500 Subject: [PATCH 29/33] improved code linting --- src/nemos/basis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 47bf5c56..075bc27b 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -1002,9 +1002,9 @@ def _compute_peaks(self): """ if self.enforce_decay_to_zero: last_peak = 1 - np.ceil(self.width) / (self.n_basis_funcs + np.ceil(self.width) - 1) - return np.linspace(0, last_peak, self.n_basis_funcs) else: - super()._compute_peaks() + last_peak = 1 + return np.linspace(0, last_peak, self.n_basis_funcs) def evaluate(self, sample_pts: ArrayLike) -> NDArray: """Generate log-spaced raised cosine basis with given samples. From 348f8afbce335dc3ae09436551fe30da88626836 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 19 Dec 2023 14:52:30 -0500 Subject: [PATCH 30/33] linted --- src/nemos/basis.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 075bc27b..8e2b48ff 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -1001,7 +1001,9 @@ def _compute_peaks(self): """ if self.enforce_decay_to_zero: - last_peak = 1 - np.ceil(self.width) / (self.n_basis_funcs + np.ceil(self.width) - 1) + last_peak = 1 - np.ceil(self.width) / ( + self.n_basis_funcs + np.ceil(self.width) - 1 + ) else: last_peak = 1 return np.linspace(0, last_peak, self.n_basis_funcs) From e168cc4610ccbf85721c2201cf6e9c636e7754d2 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 19 Dec 2023 15:04:07 -0500 Subject: [PATCH 31/33] changed code to have the very last time point being 0 --- src/nemos/basis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 8e2b48ff..1da8a50d 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -1001,8 +1001,8 @@ def _compute_peaks(self): """ if self.enforce_decay_to_zero: - last_peak = 1 - np.ceil(self.width) / ( - self.n_basis_funcs + np.ceil(self.width) - 1 + last_peak = 1 - self.width / ( + self.n_basis_funcs + self.width - 1 ) else: last_peak = 1 From 5785acc5a4ecb22eb211a4dd32ab04e9320f9f4c Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 19 Dec 2023 15:07:03 -0500 Subject: [PATCH 32/33] added comment --- src/nemos/basis.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 1da8a50d..a0159ef6 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -1001,6 +1001,8 @@ def _compute_peaks(self): """ if self.enforce_decay_to_zero: + # compute the last peak location such that the last + # basis element decays to zero at the last sample. last_peak = 1 - self.width / ( self.n_basis_funcs + self.width - 1 ) From 14055b6509057142e5bae3274cdd65e4927bf938 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 19 Dec 2023 15:09:36 -0500 Subject: [PATCH 33/33] linted --- src/nemos/basis.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index a0159ef6..44150e68 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -1003,9 +1003,7 @@ def _compute_peaks(self): if self.enforce_decay_to_zero: # compute the last peak location such that the last # basis element decays to zero at the last sample. - last_peak = 1 - self.width / ( - self.n_basis_funcs + self.width - 1 - ) + last_peak = 1 - self.width / (self.n_basis_funcs + self.width - 1) else: last_peak = 1 return np.linspace(0, last_peak, self.n_basis_funcs)