From 3d91df13062d928675adffc19cf71c5353920d59 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Sun, 15 Dec 2024 12:39:47 -0500 Subject: [PATCH] return generator for iterating over components --- src/nemos/basis/_basis_mixin.py | 24 ++++++++++++++---------- src/nemos/basis/_transformer_basis.py | 2 +- tests/test_basis.py | 11 ++++++----- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index d0c80f81..21ade27d 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -7,7 +7,8 @@ import inspect import warnings from functools import wraps -from typing import TYPE_CHECKING, Optional, Tuple, Union +from itertools import chain +from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union import numpy as np from numpy.typing import ArrayLike, NDArray @@ -88,8 +89,8 @@ def __sklearn_clone__(self) -> Basis: setattr(klass, attr_name, getattr(self, attr_name)) return klass - def _list_components(self): - """List all basis components. + def _iterate_over_components(self) -> Generator: + """Return a generator that iterates over all basis components. For atomic bases, the list is just [self]. @@ -98,7 +99,7 @@ def _list_components(self): A list with the basis components. """ - return [self] + return (x for x in [self]) def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): """ @@ -512,8 +513,8 @@ def __init__(self, basis1: Basis, basis2: Basis): self.basis2._parent = self shapes = ( - *(bas1._input_shape_ for bas1 in basis1._list_components()), - *(bas2._input_shape_ for bas2 in basis2._list_components()), + *(bas1._input_shape_ for bas1 in basis1._iterate_over_components()), + *(bas2._input_shape_ for bas2 in basis2._iterate_over_components()), ) # if all bases where set, then set input for composition. set_bases = (s is not None for s in shapes) @@ -599,17 +600,20 @@ def basis2(self): def basis2(self, bas: Basis): self._basis2 = bas - def _list_components(self): - """List all basis components. + def _iterate_over_components(self): + """Return a generator that iterates over all basis components. - Reimplements the default behavior by iteratively calling _list_components of the + Reimplements the default behavior by iteratively calling _iterate_over_components of the elements. Returns ------- A list with all 1d basis components. """ - return self._basis1._list_components() + self._basis2._list_components() + return chain( + self._basis1._iterate_over_components(), + self._basis2._iterate_over_components(), + ) @set_input_shape_state def __sklearn_clone__(self) -> Basis: diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 67063e28..4420eb40 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -102,7 +102,7 @@ def _unpack_inputs(self, X: FeatureMatrix) -> Generator: out = ( np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape_)) for i, (bas, n_input) in enumerate( - zip(self._list_components(), self._n_basis_input_) + zip(self._iterate_over_components(), self._n_basis_input_) ) for cc in [sum(self._n_basis_input_[:i])] ) diff --git a/tests/test_basis.py b/tests/test_basis.py index 770bce35..1031bc14 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -1331,14 +1331,14 @@ def test_set_input_value_types(self, inp_shape, expectation, cls): @pytest.mark.parametrize( "mode, kwargs", [("eval", {}), ("conv", {"window_size": 6})] ) - def test_list_component(self, mode, kwargs, cls): + def test_iterate_over_component(self, mode, kwargs, cls): basis_obj = cls[mode]( n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5), ) - out = basis_obj._list_components() + out = tuple(basis_obj._iterate_over_components()) assert len(out) == 1 assert id(out[0]) == id(basis_obj) @@ -1974,7 +1974,9 @@ class TestAdditiveBasis(CombinedBasis): @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) - def test_list_component(self, basis_a, basis_b, basis_class_specific_params): + def test_iterate_over_component( + self, basis_a, basis_b, basis_class_specific_params + ): basis_a_obj = self.instantiate_basis( 5, basis_a, basis_class_specific_params, window_size=10 ) @@ -1982,8 +1984,7 @@ def test_list_component(self, basis_a, basis_b, basis_class_specific_params): 6, basis_b, basis_class_specific_params, window_size=10 ) add = basis_a_obj + basis_b_obj - out = add._list_components() - + out = tuple(add._iterate_over_components()) assert len(out) == add._n_input_dimensionality def get_ids(bas):