Skip to content

Commit

Permalink
return generator for iterating over components
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 15, 2024
1 parent 5998a2d commit 3d91df1
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 16 deletions.
24 changes: 14 additions & 10 deletions src/nemos/basis/_basis_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import inspect
import warnings
from functools import wraps
from typing import TYPE_CHECKING, Optional, Tuple, Union
from itertools import chain
from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union

import numpy as np
from numpy.typing import ArrayLike, NDArray
Expand Down Expand Up @@ -88,8 +89,8 @@ def __sklearn_clone__(self) -> Basis:
setattr(klass, attr_name, getattr(self, attr_name))
return klass

def _list_components(self):
"""List all basis components.
def _iterate_over_components(self) -> Generator:
"""Return a generator that iterates over all basis components.
For atomic bases, the list is just [self].
Expand All @@ -98,7 +99,7 @@ def _list_components(self):
A list with the basis components.
"""
return [self]
return (x for x in [self])

def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Expand Down Expand Up @@ -512,8 +513,8 @@ def __init__(self, basis1: Basis, basis2: Basis):
self.basis2._parent = self

shapes = (
*(bas1._input_shape_ for bas1 in basis1._list_components()),
*(bas2._input_shape_ for bas2 in basis2._list_components()),
*(bas1._input_shape_ for bas1 in basis1._iterate_over_components()),
*(bas2._input_shape_ for bas2 in basis2._iterate_over_components()),
)
# if all bases where set, then set input for composition.
set_bases = (s is not None for s in shapes)
Expand Down Expand Up @@ -599,17 +600,20 @@ def basis2(self):
def basis2(self, bas: Basis):
self._basis2 = bas

def _list_components(self):
"""List all basis components.
def _iterate_over_components(self):
"""Return a generator that iterates over all basis components.
Reimplements the default behavior by iteratively calling _list_components of the
Reimplements the default behavior by iteratively calling _iterate_over_components of the
elements.
Returns
-------
A list with all 1d basis components.
"""
return self._basis1._list_components() + self._basis2._list_components()
return chain(
self._basis1._iterate_over_components(),
self._basis2._iterate_over_components(),
)

@set_input_shape_state
def __sklearn_clone__(self) -> Basis:
Expand Down
2 changes: 1 addition & 1 deletion src/nemos/basis/_transformer_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _unpack_inputs(self, X: FeatureMatrix) -> Generator:
out = (
np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape_))
for i, (bas, n_input) in enumerate(
zip(self._list_components(), self._n_basis_input_)
zip(self._iterate_over_components(), self._n_basis_input_)
)
for cc in [sum(self._n_basis_input_[:i])]
)
Expand Down
11 changes: 6 additions & 5 deletions tests/test_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,14 +1331,14 @@ def test_set_input_value_types(self, inp_shape, expectation, cls):
@pytest.mark.parametrize(
"mode, kwargs", [("eval", {}), ("conv", {"window_size": 6})]
)
def test_list_component(self, mode, kwargs, cls):
def test_iterate_over_component(self, mode, kwargs, cls):
basis_obj = cls[mode](
n_basis_funcs=5,
**kwargs,
**extra_decay_rates(cls[mode], 5),
)

out = basis_obj._list_components()
out = tuple(basis_obj._iterate_over_components())
assert len(out) == 1
assert id(out[0]) == id(basis_obj)

Expand Down Expand Up @@ -1974,16 +1974,17 @@ class TestAdditiveBasis(CombinedBasis):

@pytest.mark.parametrize("basis_a", list_all_basis_classes())
@pytest.mark.parametrize("basis_b", list_all_basis_classes())
def test_list_component(self, basis_a, basis_b, basis_class_specific_params):
def test_iterate_over_component(
self, basis_a, basis_b, basis_class_specific_params
):
basis_a_obj = self.instantiate_basis(
5, basis_a, basis_class_specific_params, window_size=10
)
basis_b_obj = self.instantiate_basis(
6, basis_b, basis_class_specific_params, window_size=10
)
add = basis_a_obj + basis_b_obj
out = add._list_components()

out = tuple(add._iterate_over_components())
assert len(out) == add._n_input_dimensionality

def get_ids(bas):
Expand Down

0 comments on commit 3d91df1

Please sign in to comment.