From 50548f0ff43ca34ce12555ce6f5a7c9a64b5a120 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 13 Dec 2024 12:32:28 -0500 Subject: [PATCH] improved unpacking --- src/nemos/basis/_transformer_basis.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 91865028..c2b5f51b 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -100,14 +100,11 @@ def _unpack_inputs(self, X: FeatureMatrix) -> List: """ n_samples = X.shape[0] - out = [] - cc = 0 - for i, bas in enumerate(self._list_components()): - n_input = self._n_basis_input_[i] - out.append( - np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape_)) - ) - cc += n_input + out = [ + np.reshape(X[:, cc: cc + n_input], (n_samples, *bas._input_shape_)) + for i, (bas, n_input) in enumerate(zip(self._list_components(), self._n_basis_input_)) + for cc in [sum(self._n_basis_input_[:i])] + ] return out def fit(self, X: FeatureMatrix, y=None):