Skip to content

Commit

Permalink
Adding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gviejo committed Oct 10, 2024
1 parent c4f3c53 commit a8c8e36
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 13 deletions.
5 changes: 3 additions & 2 deletions docs/how_to_guide/plot_04_population_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
"""

import jax.numpy as jnp
import nemos as nmo
import numpy as np
import matplotlib.pyplot as plt
import numpy as np

import nemos as nmo

np.random.seed(123)

Expand Down
5 changes: 3 additions & 2 deletions docs/how_to_guide/plot_05_batch_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
"""

import matplotlib.pyplot as plt
import numpy as np
import pynapple as nap

import nemos as nmo
import numpy as np
import matplotlib.pyplot as plt

nap.nap_config.suppress_conversion_warnings = True

Expand Down
9 changes: 4 additions & 5 deletions docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,19 @@
# ## Combining basis transformations and GLM in a pipeline
# Let's start by creating some toy data.

import nemos as nmo
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline

import nemos as nmo

# some helper plotting functions
from nemos import _documentation_utils as doc_plots


# predictors, shape (n_samples, n_features)
X = np.random.uniform(low=0, high=1, size=(1000, 1))
# observed counts, shape (n_samples,)
Expand Down
10 changes: 6 additions & 4 deletions src/nemos/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ def _unpack_inputs(X: FeatureMatrix):
A tuple of each individual input.
"""

return (X[:, k] for k in range(X.shape[1]))

def fit(self, X: FeatureMatrix, y=None):
Expand Down Expand Up @@ -947,7 +946,8 @@ def __len__(self) -> int:

def __mul__(self, other: (Basis, int)) -> MultiplicativeBasis:
"""
Multiply two Basis objects together.
Multiply two Basis objects together or replicate the basis
by multiplying it with an integer.
Parameters
----------
Expand All @@ -965,11 +965,13 @@ def __mul__(self, other: (Basis, int)) -> MultiplicativeBasis:
if other <= 0:
raise ValueError("Multiplier should be a non-negative integer!")
result = self
for _ in range(other-1):
for _ in range(other - 1):
result = result + self
return result
else:
raise TypeError("Basis can only be multiplied with another basis or an integer!")
raise TypeError(
"Basis can only be multiplied with another basis or an integer!"
)

def __pow__(self, exponent: int) -> MultiplicativeBasis:
"""Exponentiation of a Basis object.
Expand Down
59 changes: 59 additions & 0 deletions tests/test_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3202,6 +3202,65 @@ def test_compute_features_input(self, eval_input):
basis_obj = basis.MSplineBasis(5) + basis.MSplineBasis(5)
basis_obj.compute_features(*eval_input)

@pytest.mark.parametrize("n_basis_a", [5, 6])
@pytest.mark.parametrize("n_basis_b", [5, 6])
@pytest.mark.parametrize("basis_a", list_all_basis_classes())
@pytest.mark.parametrize("basis_b", list_all_basis_classes())
def test_len(
self, n_basis_a, n_basis_b, basis_a, basis_b,
):
"""
Test for __len__ of basis
"""
# define the two basis
basis_a_obj = self.instantiate_basis(
n_basis_a, basis_a, mode="eval"
)
basis_b_obj = self.instantiate_basis(
n_basis_b, basis_b, mode="eval"
)

basis_obj = basis_a_obj + basis_b_obj

assert hasattr(basis_obj, "__len__")
assert len(basis_a_obj) == basis_a_obj.n_basis_funcs
assert len(basis_b_obj) == basis_b_obj.n_basis_funcs
assert len(basis_obj) == basis_obj.n_basis_funcs

@pytest.mark.parametrize("n", [1, 6])
@pytest.mark.parametrize("basis", list_all_basis_classes())
def test_basis_multiply_with_integer(
self, n, basis,
):
"""
Test for __mul__ of basis with integer
"""
# define the two basis
basis_obj = self.instantiate_basis(
5, basis, mode="eval"
)
new_basis_obj = basis_obj * n

assert new_basis_obj.n_basis_funcs == n * basis_obj.n_basis_funcs

@pytest.mark.parametrize("basis", list_all_basis_classes())
@pytest.mark.parametrize("n, expected", [
(-2, pytest.raises(ValueError, match=r"Multiplier should be a non-negative integer!")),
("6", pytest.raises(TypeError, match=r"Basis can only be multiplied with another basis or an integer!"))
])
def test_basis_multiply_errors(
self, basis, n, expected
):
"""
Test for __mul__ of basis. raise errors
"""
# define the two basis
basis_obj = self.instantiate_basis(
5, basis, mode="eval"
)
with expected:
basis_obj * n

@pytest.mark.parametrize("n_basis_a", [5, 6])
@pytest.mark.parametrize("n_basis_b", [5, 6])
@pytest.mark.parametrize("sample_size", [10, 1000])
Expand Down

0 comments on commit a8c8e36

Please sign in to comment.