Skip to content

Commit

Permalink
improved unpacking
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 13, 2024
1 parent ac6323e commit 50548f0
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions src/nemos/basis/_transformer_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,11 @@ def _unpack_inputs(self, X: FeatureMatrix) -> List:
"""
n_samples = X.shape[0]
out = []
cc = 0
for i, bas in enumerate(self._list_components()):
n_input = self._n_basis_input_[i]
out.append(
np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape_))
)
cc += n_input
out = [
np.reshape(X[:, cc: cc + n_input], (n_samples, *bas._input_shape_))
for i, (bas, n_input) in enumerate(zip(self._list_components(), self._n_basis_input_))
for cc in [sum(self._n_basis_input_[:i])]
]
return out

def fit(self, X: FeatureMatrix, y=None):
Expand Down

0 comments on commit 50548f0

Please sign in to comment.