From 4234c517c46d7a9a305bd63c8856672290a31e3b Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Sun, 15 Dec 2024 12:16:14 -0500 Subject: [PATCH] add a generator --- src/nemos/basis/_transformer_basis.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 850aed64..5a76b23f 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -2,7 +2,7 @@ import copy from functools import wraps -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, Generator import numpy as np @@ -102,7 +102,7 @@ def basis(self): def basis(self, basis): self._basis = basis - def _unpack_inputs(self, X: FeatureMatrix) -> List: + def _unpack_inputs(self, X: FeatureMatrix) -> Generator: """Unpack inputs. Unpack horizontally stacked inputs using slicing. This works gracefully with ``pynapple``, @@ -120,13 +120,13 @@ def _unpack_inputs(self, X: FeatureMatrix) -> List: """ n_samples = X.shape[0] - out = [ + out = ( np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape_)) for i, (bas, n_input) in enumerate( zip(self._list_components(), self._n_basis_input_) ) for cc in [sum(self._n_basis_input_[:i])] - ] + ) return out def fit(self, X: FeatureMatrix, y=None):