diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 91865028..c2b5f51b 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -100,14 +100,11 @@ def _unpack_inputs(self, X: FeatureMatrix) -> List: """ n_samples = X.shape[0] - out = [] - cc = 0 - for i, bas in enumerate(self._list_components()): - n_input = self._n_basis_input_[i] - out.append( - np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape_)) - ) - cc += n_input + out = [ + np.reshape(X[:, cc: cc + n_input], (n_samples, *bas._input_shape_)) + for i, (bas, n_input) in enumerate(zip(self._list_components(), self._n_basis_input_)) + for cc in [sum(self._n_basis_input_[:i])] + ] return out def fit(self, X: FeatureMatrix, y=None):