diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 16c01e20..56c04a06 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -99,13 +99,13 @@ def _unpack_inputs(self, X: FeatureMatrix) -> List: """ n_samples = X.shape[0] - out = [ + out = ( np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape_)) for i, (bas, n_input) in enumerate( zip(self._list_components(), self._n_basis_input_) ) for cc in [sum(self._n_basis_input_[:i])] - ] + ) return out def fit(self, X: FeatureMatrix, y=None):