diff --git a/pysaliency/utils/__init__.py b/pysaliency/utils/__init__.py index 88f1e59..316eb7b 100644 --- a/pysaliency/utils/__init__.py +++ b/pysaliency/utils/__init__.py @@ -28,7 +28,7 @@ def build_padded_2d_array(arrays, max_length=None, padding_value=np.nan): max_length = np.max([len(a) for a in arrays]) #output = np.ones((len(arrays), max_length), dtype=np.asarray(arrays[0]).dtype) - dtype = np.asarray(arrays[0]).dtype + dtype = np.asarray(arrays[0]).dtype if arrays else np.float64 if np.issubdtype(dtype, np.integer) and padding_value is np.nan: dtype = np.float64 diff --git a/pysaliency/utils/variable_length_array.py b/pysaliency/utils/variable_length_array.py index f31ee2e..e171e72 100644 --- a/pysaliency/utils/variable_length_array.py +++ b/pysaliency/utils/variable_length_array.py @@ -52,7 +52,7 @@ def __init__(self, data: Union[np.ndarray, List[list]], lengths: Optional[np.nda else: if not data.ndim >= 2: raise ValueError("If data is a numpy array, it has to be at least 2-dimensional") - if np.max(lengths) > data.shape[1]: + if lengths and np.max(lengths) > data.shape[1]: raise ValueError("The specified lengths are larger than the number of columns in the data array") else: @@ -63,7 +63,7 @@ def __init__(self, data: Union[np.ndarray, List[list]], lengths: Optional[np.nda if isinstance(data, np.ndarray): self._data = data else: - self._data = build_padded_2d_array(data, max_length=np.max(lengths)) + self._data = build_padded_2d_array(data, max_length=np.max(lengths) if lengths else 0) # max_len = np.max(lengths) # self._data = np.full((len(data), max_len), np.nan) @@ -96,11 +96,11 @@ def __getitem__(self, index): # new_data = self._data[index, :max_length] # return VariableLengthArray(new_data, new_lengths) - def copy(self): + def copy(self) -> 'VariableLengthArray': return VariableLengthArray(self._data.copy(), self.lengths.copy()) -def concatenate_variable_length_arrays(arrays: List[VariableLengthArray]): +def concatenate_variable_length_arrays(arrays: List[VariableLengthArray]) -> VariableLengthArray: """ Concatenate a list of VariableLengthArray objects along the first axis. diff --git a/tests/utils/test_variable_length_array.py b/tests/utils/test_variable_length_array.py index c364880..c339eed 100644 --- a/tests/utils/test_variable_length_array.py +++ b/tests/utils/test_variable_length_array.py @@ -174,4 +174,26 @@ def test_variable_length_array_concatenate(): expected = VariableLengthArray(concatenated_data) np.testing.assert_array_equal(concatenated_array._data, expected._data) - np.testing.assert_array_equal(concatenated_array.lengths, expected.lengths) \ No newline at end of file + np.testing.assert_array_equal(concatenated_array.lengths, expected.lengths) + + +def test_variable_length_array_zero_length_from_2d_array(): + data = np.empty((0, 0)) + lengths = np.array([]) + array = VariableLengthArray(data, lengths) + + assert len(array) == 0 + + assert np.array_equal(array._data, np.empty((0, 0))) + assert np.array_equal(array.lengths, np.array([])) + + +def test_variable_length_array_zero_length_from_list(): + data = [] + lengths = np.array([]) + array = VariableLengthArray(data, lengths) + + assert len(array) == 0 + + assert np.array_equal(array._data, np.empty((0, 0))) + assert np.array_equal(array.lengths, np.array([])) \ No newline at end of file