diff --git a/docs/how_to_guide/plot_04_population_glm.py b/docs/how_to_guide/plot_04_population_glm.py index 84282477..70dac9cd 100644 --- a/docs/how_to_guide/plot_04_population_glm.py +++ b/docs/how_to_guide/plot_04_population_glm.py @@ -23,9 +23,10 @@ """ import jax.numpy as jnp -import nemos as nmo -import numpy as np import matplotlib.pyplot as plt +import numpy as np + +import nemos as nmo np.random.seed(123) diff --git a/docs/how_to_guide/plot_05_batch_glm.py b/docs/how_to_guide/plot_05_batch_glm.py index f9e758fc..84f64d98 100644 --- a/docs/how_to_guide/plot_05_batch_glm.py +++ b/docs/how_to_guide/plot_05_batch_glm.py @@ -6,10 +6,11 @@ """ +import matplotlib.pyplot as plt +import numpy as np import pynapple as nap + import nemos as nmo -import numpy as np -import matplotlib.pyplot as plt nap.nap_config.suppress_conversion_warnings = True diff --git a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py index b7168e33..ca9b167a 100644 --- a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py +++ b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py @@ -71,20 +71,19 @@ # ## Combining basis transformations and GLM in a pipeline # Let's start by creating some toy data. -import nemos as nmo +import matplotlib.pyplot as plt import numpy as np import pandas as pd import scipy.stats -import matplotlib.pyplot as plt import seaborn as sns - -from sklearn.pipeline import Pipeline from sklearn.model_selection import GridSearchCV +from sklearn.pipeline import Pipeline + +import nemos as nmo # some helper plotting functions from nemos import _documentation_utils as doc_plots - # predictors, shape (n_samples, n_features) X = np.random.uniform(low=0, high=1, size=(1000, 1)) # observed counts, shape (n_samples,) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 3bf13b50..97c099a7 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -185,7 +185,6 @@ def _unpack_inputs(X: FeatureMatrix): A tuple of each individual input. """ - return (X[:, k] for k in range(X.shape[1])) def fit(self, X: FeatureMatrix, y=None): @@ -947,7 +946,8 @@ def __len__(self) -> int: def __mul__(self, other: (Basis, int)) -> MultiplicativeBasis: """ - Multiply two Basis objects together. + Multiply two Basis objects together or replicate the basis + by multiplying it with an integer. Parameters ---------- @@ -965,11 +965,13 @@ def __mul__(self, other: (Basis, int)) -> MultiplicativeBasis: if other <= 0: raise ValueError("Multiplier should be a non-negative integer!") result = self - for _ in range(other-1): + for _ in range(other - 1): result = result + self return result else: - raise TypeError("Basis can only be multiplied with another basis or an integer!") + raise TypeError( + "Basis can only be multiplied with another basis or an integer!" + ) def __pow__(self, exponent: int) -> MultiplicativeBasis: """Exponentiation of a Basis object. diff --git a/tests/test_basis.py b/tests/test_basis.py index 6e81d142..7bc3d797 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -3202,6 +3202,65 @@ def test_compute_features_input(self, eval_input): basis_obj = basis.MSplineBasis(5) + basis.MSplineBasis(5) basis_obj.compute_features(*eval_input) + @pytest.mark.parametrize("n_basis_a", [5, 6]) + @pytest.mark.parametrize("n_basis_b", [5, 6]) + @pytest.mark.parametrize("basis_a", list_all_basis_classes()) + @pytest.mark.parametrize("basis_b", list_all_basis_classes()) + def test_len( + self, n_basis_a, n_basis_b, basis_a, basis_b, + ): + """ + Test for __len__ of basis + """ + # define the two basis + basis_a_obj = self.instantiate_basis( + n_basis_a, basis_a, mode="eval" + ) + basis_b_obj = self.instantiate_basis( + n_basis_b, basis_b, mode="eval" + ) + + basis_obj = basis_a_obj + basis_b_obj + + assert hasattr(basis_obj, "__len__") + assert len(basis_a_obj) == basis_a_obj.n_basis_funcs + assert len(basis_b_obj) == basis_b_obj.n_basis_funcs + assert len(basis_obj) == basis_obj.n_basis_funcs + + @pytest.mark.parametrize("n", [1, 6]) + @pytest.mark.parametrize("basis", list_all_basis_classes()) + def test_basis_multiply_with_integer( + self, n, basis, + ): + """ + Test for __mul__ of basis with integer + """ + # define the two basis + basis_obj = self.instantiate_basis( + 5, basis, mode="eval" + ) + new_basis_obj = basis_obj * n + + assert new_basis_obj.n_basis_funcs == n * basis_obj.n_basis_funcs + + @pytest.mark.parametrize("basis", list_all_basis_classes()) + @pytest.mark.parametrize("n, expected", [ + (-2, pytest.raises(ValueError, match=r"Multiplier should be a non-negative integer!")), + ("6", pytest.raises(TypeError, match=r"Basis can only be multiplied with another basis or an integer!")) + ]) + def test_basis_multiply_errors( + self, basis, n, expected + ): + """ + Test for __mul__ of basis. raise errors + """ + # define the two basis + basis_obj = self.instantiate_basis( + 5, basis, mode="eval" + ) + with expected: + basis_obj * n + @pytest.mark.parametrize("n_basis_a", [5, 6]) @pytest.mark.parametrize("n_basis_b", [5, 6]) @pytest.mark.parametrize("sample_size", [10, 1000])