Skip to content

Commit

Permalink
Fix: VariableLengthArray failed for zero length
Browse files Browse the repository at this point in the history
Signed-off-by: Matthias Kümmerer <[email protected]>
  • Loading branch information
matthias-k committed Mar 28, 2024
1 parent 76ad069 commit 9828783
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pysaliency/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def build_padded_2d_array(arrays, max_length=None, padding_value=np.nan):
max_length = np.max([len(a) for a in arrays])

#output = np.ones((len(arrays), max_length), dtype=np.asarray(arrays[0]).dtype)
dtype = np.asarray(arrays[0]).dtype
dtype = np.asarray(arrays[0]).dtype if arrays else np.float64

if np.issubdtype(dtype, np.integer) and padding_value is np.nan:
dtype = np.float64
Expand Down
8 changes: 4 additions & 4 deletions pysaliency/utils/variable_length_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, data: Union[np.ndarray, List[list]], lengths: Optional[np.nda
else:
if not data.ndim >= 2:
raise ValueError("If data is a numpy array, it has to be at least 2-dimensional")
if np.max(lengths) > data.shape[1]:
if lengths and np.max(lengths) > data.shape[1]:
raise ValueError("The specified lengths are larger than the number of columns in the data array")

else:
Expand All @@ -63,7 +63,7 @@ def __init__(self, data: Union[np.ndarray, List[list]], lengths: Optional[np.nda
if isinstance(data, np.ndarray):
self._data = data
else:
self._data = build_padded_2d_array(data, max_length=np.max(lengths))
self._data = build_padded_2d_array(data, max_length=np.max(lengths) if lengths else 0)

# max_len = np.max(lengths)
# self._data = np.full((len(data), max_len), np.nan)
Expand Down Expand Up @@ -96,11 +96,11 @@ def __getitem__(self, index):
# new_data = self._data[index, :max_length]
# return VariableLengthArray(new_data, new_lengths)

def copy(self):
def copy(self) -> 'VariableLengthArray':
return VariableLengthArray(self._data.copy(), self.lengths.copy())


def concatenate_variable_length_arrays(arrays: List[VariableLengthArray]):
def concatenate_variable_length_arrays(arrays: List[VariableLengthArray]) -> VariableLengthArray:
"""
Concatenate a list of VariableLengthArray objects along the first axis.
Expand Down
24 changes: 23 additions & 1 deletion tests/utils/test_variable_length_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,26 @@ def test_variable_length_array_concatenate():
expected = VariableLengthArray(concatenated_data)

np.testing.assert_array_equal(concatenated_array._data, expected._data)
np.testing.assert_array_equal(concatenated_array.lengths, expected.lengths)
np.testing.assert_array_equal(concatenated_array.lengths, expected.lengths)


def test_variable_length_array_zero_length_from_2d_array():
data = np.empty((0, 0))
lengths = np.array([])
array = VariableLengthArray(data, lengths)

assert len(array) == 0

assert np.array_equal(array._data, np.empty((0, 0)))
assert np.array_equal(array.lengths, np.array([]))


def test_variable_length_array_zero_length_from_list():
data = []
lengths = np.array([])
array = VariableLengthArray(data, lengths)

assert len(array) == 0

assert np.array_equal(array._data, np.empty((0, 0)))
assert np.array_equal(array.lengths, np.array([]))

0 comments on commit 9828783

Please sign in to comment.