Skip to content

Commit

Permalink
add a generator
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 15, 2024
1 parent 95d0cd4 commit 4234c51
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/nemos/basis/_transformer_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import copy
from functools import wraps
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, Generator

import numpy as np

Expand Down Expand Up @@ -102,7 +102,7 @@ def basis(self):
def basis(self, basis):
self._basis = basis

def _unpack_inputs(self, X: FeatureMatrix) -> List:
def _unpack_inputs(self, X: FeatureMatrix) -> Generator:
"""Unpack inputs.
Unpack horizontally stacked inputs using slicing. This works gracefully with ``pynapple``,
Expand All @@ -120,13 +120,13 @@ def _unpack_inputs(self, X: FeatureMatrix) -> List:
"""
n_samples = X.shape[0]
out = [
out = (
np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape_))
for i, (bas, n_input) in enumerate(
zip(self._list_components(), self._n_basis_input_)
)
for cc in [sum(self._n_basis_input_[:i])]
]
)
return out

def fit(self, X: FeatureMatrix, y=None):
Expand Down

0 comments on commit 4234c51

Please sign in to comment.