Skip to content

Commit

Permalink
improved coverage transformer basis
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 16, 2024
1 parent 33fa274 commit ad88294
Showing 2 changed files with 40 additions and 20 deletions.
2 changes: 1 addition & 1 deletion src/nemos/basis/_transformer_basis.py
Original file line number Diff line number Diff line change
@@ -211,8 +211,8 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix:
>>> # Transform basis
>>> feature_transformed = transformer.transform(X)
"""
self._check_input(X, y)
self._check_initialized(self._basis)
self._check_input(X, y)
# transpose does not work with pynapple
# can't use func(*X.T) to unwrap
return self._basis._compute_features(*self._unpack_inputs(X))
58 changes: 39 additions & 19 deletions tests/test_transformer_basis.py
Original file line number Diff line number Diff line change
@@ -3,15 +3,14 @@

import numpy as np
import pytest

from conftest import CombinedBasis, list_all_basis_classes
from sklearn.base import clone as sk_clone
from sklearn.pipeline import Pipeline

import nemos as nmo
from nemos import basis
from nemos._inspect_utils import list_abstract_methods, get_subclass_methods
from nemos.basis import AdditiveBasis, MultiplicativeBasis
from nemos._inspect_utils import get_subclass_methods, list_abstract_methods
from nemos.basis import AdditiveBasis, MSplineConv, MultiplicativeBasis


@pytest.mark.parametrize(
@@ -699,12 +698,13 @@ def test_transformer_fit_transform_input_struct(
@pytest.mark.parametrize(
"inp",
[
np.random.randn(
0.1
* np.random.randn(
100,
),
np.random.randn(100, 1),
np.random.randn(100, 2),
np.random.randn(100, 1, 2),
0.1 * np.random.randn(100, 1),
0.1 * np.random.randn(100, 2),
0.1 * np.random.randn(100, 1, 2),
],
)
def test_transformer_in_pipeline(basis_cls, inp, basis_class_specific_params):
@@ -756,6 +756,7 @@ def test_transformer_in_pipeline(basis_cls, inp, basis_class_specific_params):
model.fit(X, y)
np.testing.assert_allclose(pipe["glm"].coef_, model.coef_)


@pytest.mark.parametrize(
"basis_cls",
list_all_basis_classes(),
@@ -766,7 +767,7 @@ def test_initialization(basis_cls, basis_class_specific_params):
)
transformer = bas.to_transformer()
with pytest.raises(RuntimeError, match="Cannot apply TransformerBasis"):
transformer.fit(np.ones((100, )))
transformer.fit(np.ones((100,)))

with pytest.raises(RuntimeError, match="Cannot apply TransformerBasis"):
transformer.transform(np.ones((100,)))
@@ -784,7 +785,9 @@ def test_basis_setter(basis_cls, basis_class_specific_params):
5, basis_cls, basis_class_specific_params, window_size=10
)

bas2 = CombinedBasis().instantiate_basis(7, basis_cls, basis_class_specific_params, window_size=10)
bas2 = CombinedBasis().instantiate_basis(
7, basis_cls, basis_class_specific_params, window_size=10
)
transformer = bas.to_transformer()
transformer.basis = bas2
assert transformer.basis.n_basis_funcs == bas2.n_basis_funcs
@@ -811,17 +814,20 @@ def test_eetstate(basis_cls, basis_class_specific_params):
bas = CombinedBasis().instantiate_basis(
5, basis_cls, basis_class_specific_params, window_size=10
)
bas2 = CombinedBasis().instantiate_basis(7, basis_cls, basis_class_specific_params, window_size=10)
bas2 = CombinedBasis().instantiate_basis(
7, basis_cls, basis_class_specific_params, window_size=10
)
transformer = bas.to_transformer()
state = {"_basis": bas2}
transformer.__setstate__(state)
assert transformer.basis == bas2
assert transformer.basis == bas2


@pytest.mark.parametrize(
"basis_cls",
list_all_basis_classes(),
)
def test_eetstate(basis_cls, basis_class_specific_params):
def test_getstate(basis_cls, basis_class_specific_params):
bas = CombinedBasis().instantiate_basis(
5, basis_cls, basis_class_specific_params, window_size=10
)
@@ -835,25 +841,35 @@ def test_eetstate(basis_cls, basis_class_specific_params):

# check all reimplemented methods
dict_reimplemented_method = get_subclass_methods(basis_cls)
for meth in dict_abst_method:
for meth in dict_reimplemented_method:
assert meth[0] in lst

# check that it is a trnasformer
for meth in ["fit", "transform", "fit_transform"]:
assert meth in lst


@pytest.mark.parametrize(
"basis_cls",
list_all_basis_classes(),
)
@pytest.mark.parametrize(
"inp, expectation",
[
(np.random.randn(10, 2), pytest.raises(ValueError, match="Input mismatch: expected \d inputs")),
(np.random.randn(10, 3, 1), pytest.raises(ValueError, match="X must be 2-dimensional")),
({1: np.random.randn(10, 3)}, pytest.raises(ValueError, match="The input must be a 2-dimensional array")),
(
np.random.randn(10, 2),
pytest.raises(ValueError, match="Input mismatch: expected \d inputs"),
),
(
np.random.randn(10, 3, 1),
pytest.raises(ValueError, match="X must be 2-dimensional"),
),
(
{1: np.random.randn(10, 3)},
pytest.raises(ValueError, match="The input must be a 2-dimensional array"),
),
(np.random.randn(10, 3), does_not_raise()),
]
],
)
@pytest.mark.parametrize("method", ["fit", "transform", "fit_transform"])
def test_check_input(inp, expectation, basis_cls, basis_class_specific_params, method):
@@ -863,11 +879,15 @@ def test_check_input(inp, expectation, basis_cls, basis_class_specific_params, m
# set kernels
bas._set_input_independent_states()
# set input shape
transformer = bas.to_transformer().set_input_shape(*([3] * bas._n_input_dimensionality))
transformer = bas.to_transformer().set_input_shape(
*([3] * bas._n_input_dimensionality)
)
if isinstance(bas, (AdditiveBasis, MultiplicativeBasis)):
if hasattr(inp, "ndim"):
ndim = inp.ndim
inp = np.concatenate([inp.reshape(inp.shape[0], -1)] * bas._n_input_dimensionality, axis=1)
inp = np.concatenate(
[inp.reshape(inp.shape[0], -1)] * bas._n_input_dimensionality, axis=1
)
if ndim == 3:
inp = inp[..., np.newaxis]

0 comments on commit ad88294

Please sign in to comment.