Skip to content

Commit

Permalink
Merge branch 'set_shape_basis_method' into improve_transformer_api
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 16, 2024
2 parents ad88294 + 8bff762 commit 1112dfe
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 17 deletions.
12 changes: 4 additions & 8 deletions src/nemos/basis/_basis_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,11 +307,7 @@ def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor):
structure: a single (X, y) pair for the transformer, a number of time series for the Basis.
"""
if self.kernel_ is None:
raise ValueError(
"You must call `setup_basis` before `_compute_features`! "
"Convolution kernel is not set."
)
self._check_has_kernel()
# before calling the convolve, check that the input matches
# the expectation. We can check xi[0] only, since convolution
# is applied at the end of the recursion on the 1D basis, ensuring len(xi) == 1.
Expand Down Expand Up @@ -457,8 +453,8 @@ def _check_convolution_kwargs(conv_kwargs: dict):
def _check_has_kernel(self) -> None:
"""Check that the kernel is pre-computed."""
if self.kernel_ is None:
raise ValueError(
"You must call `_set_kernel` before `_compute_features` for Conv basis."
raise RuntimeError(
"You must call `setup_basis` before `_compute_features` for Conv basis."
)


Expand Down Expand Up @@ -517,7 +513,7 @@ def __init__(self, basis1: Basis, basis2: Basis):
*(bas2._input_shape_ for bas2 in basis2._iterate_over_components()),
)
# if all bases where set, then set input for composition.
set_bases = (s is not None for s in shapes)
set_bases = [s is not None for s in shapes]

if all(set_bases):
# pass down the input shapes
Expand Down
87 changes: 78 additions & 9 deletions tests/test_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,15 @@ def method(self):
pass

assert CustomSubClass().method.__doc__ == "My extra text.\nMy custom method."
with pytest.raises(AttributeError, match="CustomClass has no attribute"):

class CustomSubClass2(CustomClass):
@custom_add_docstring("unknown", cls=CustomClass)
def method(self):
"""My custom method."""
pass

CustomSubClass2()


@pytest.mark.parametrize(
Expand Down Expand Up @@ -241,18 +250,34 @@ def test_expected_output_compute_features(basis_instance, super_class):
),
OrthExponentialBasis,
),
(
basis.OrthExponentialConv(
10, decay_rates=np.arange(1, 11), window_size=12, label="a"
)
* basis.RaisedCosineLogConv(10, window_size=11, label="b"),
OrthExponentialBasis,
),
(
basis.OrthExponentialConv(
10, decay_rates=np.arange(1, 11), window_size=12, label="a"
)
+ basis.RaisedCosineLogConv(10, window_size=11, label="b"),
OrthExponentialBasis,
),
],
)
def test_expected_output_split_by_feature(basis_instance, super_class):
x = super_class.compute_features(basis_instance, np.linspace(0, 1, 100))
inp = [np.linspace(0, 1, 100)] * basis_instance._n_input_dimensionality
x = super_class.compute_features(basis_instance, *inp)
xdict = super_class.split_by_feature(basis_instance, x)
xxdict = basis_instance.split_by_feature(x)
assert xdict.keys() == xxdict.keys()
xx = xxdict["label"]
x = xdict["label"]
nans = np.isnan(x.sum(axis=(1, 2)))
assert np.all(np.isnan(xx[nans]))
np.testing.assert_array_equal(xx[~nans], x[~nans])
for k in xdict.keys():
xx = xxdict[k]
x = xdict[k]
nans = np.isnan(x.sum(axis=(1, 2)))
assert np.all(np.isnan(xx[nans]))
np.testing.assert_array_equal(xx[~nans], x[~nans])


@pytest.mark.parametrize(
Expand Down Expand Up @@ -1236,7 +1261,7 @@ def test_transform_fails(self, cls):
n_basis_funcs=5, window_size=5, **extra_decay_rates(cls["conv"], 5)
)
with pytest.raises(
ValueError, match="You must call `setup_basis` before `_compute_features`"
RuntimeError, match="You must call `setup_basis` before `_compute_features`"
):
bas._compute_features(np.linspace(0, 1, 10))

Expand Down Expand Up @@ -1579,6 +1604,15 @@ def test_minimum_number_of_basis_required_is_matched(
n_basis_funcs=n_basis_funcs, order=order, **kwargs
)
basis_obj.compute_features(np.linspace(0, 1, 10))

# test the setter valuerror
if (order > 1) & (n_basis_funcs > 1):
basis_obj = self.cls[mode](n_basis_funcs=20, order=order, **kwargs)
with pytest.raises(
ValueError,
match=rf"{self.cls[mode].__name__} `order` parameter cannot be larger than",
):
basis_obj.n_basis_funcs = n_basis_funcs
else:
basis_obj = self.cls[mode](
n_basis_funcs=n_basis_funcs, order=order, **kwargs
Expand Down Expand Up @@ -2210,6 +2244,41 @@ def test_number_of_required_inputs_compute_features(
with expectation:
basis_obj.compute_features(*inputs)

@pytest.mark.parametrize("basis_a", list_all_basis_classes())
@pytest.mark.parametrize("basis_b", list_all_basis_classes())
@pytest.mark.parametrize("n_basis_a", [5])
@pytest.mark.parametrize("n_basis_b", [6])
@pytest.mark.parametrize("window_size", [10])
def test_warn_partial_setup(
self,
n_basis_a,
n_basis_b,
basis_a,
basis_b,
window_size,
basis_class_specific_params,
):
basis_a_obj = self.instantiate_basis(
n_basis_a, basis_a, basis_class_specific_params, window_size=window_size
)
basis_b_obj = self.instantiate_basis(
n_basis_b, basis_b, basis_class_specific_params, window_size=window_size
)

basis_a_obj.set_input_shape(*([1] * basis_a_obj._n_input_dimensionality))
with pytest.warns(UserWarning, match="Only some of the basis where"):
basis_a_obj + basis_b_obj

# check that if both set addition is fine
basis_b_obj.set_input_shape(*([1] * basis_b_obj._n_input_dimensionality))
basis_a_obj + basis_b_obj

basis_a_obj = self.instantiate_basis(
n_basis_a, basis_a, basis_class_specific_params, window_size=window_size
)
with pytest.warns(UserWarning, match="Only some of the basis where"):
basis_a_obj + basis_b_obj

@pytest.mark.parametrize("sample_size", [11, 20])
@pytest.mark.parametrize("basis_a", list_all_basis_classes())
@pytest.mark.parametrize("basis_b", list_all_basis_classes())
Expand Down Expand Up @@ -2661,7 +2730,7 @@ def test_transform_fails(
context = does_not_raise()
else:
context = pytest.raises(
ValueError,
RuntimeError,
match="You must call `setup_basis` before `_compute_features`",
)
with context:
Expand Down Expand Up @@ -3589,7 +3658,7 @@ def test_transform_fails(
context = does_not_raise()
else:
context = pytest.raises(
ValueError,
RuntimeError,
match="You must call `setup_basis` before `_compute_features`",
)
with context:
Expand Down

0 comments on commit 1112dfe

Please sign in to comment.