-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feature: VariableLengthArray for scanpath coordinates, histories etc
Signed-off-by: Matthias Kümmerer <[email protected]>
- Loading branch information
1 parent
245cda9
commit 3dbe8dc
Showing
4 changed files
with
265 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,3 +10,4 @@ | |
*.c | ||
*.so | ||
*.egg-info | ||
.vscode |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
from typing import Optional, Union | ||
|
||
import numpy as np | ||
|
||
from . import build_padded_2d_array | ||
|
||
|
||
class VariableLengthArray: | ||
""" | ||
Represents a variable length array. | ||
The following indexing operations are supported: | ||
- Accessing rows: array[i] | ||
- Accessing elements: array[i, j] where j can also be negative to get elements from the end of the row | ||
- Slicing: array[i:j, k] where k can also be negative to get elements from the end of each row | ||
Args: | ||
data (Union[np.ndarray, list[list]]): The data for the array. Can be either a numpy array or a list of lists. | ||
lengths (np.ndarray): The lengths of each row in the data array. | ||
Attributes: | ||
_data (np.ndarray): The internal data array with padded rows. | ||
lengths (np.ndarray): The lengths of each row in the data array. | ||
Methods: | ||
__len__(): Returns the number of rows in the array. | ||
__getitem__(index): Returns the value(s) at the specified index(es) in the array. | ||
""" | ||
|
||
def __init__(self, data: Union[np.ndarray, list[list]], lengths: Optional[np.ndarray] = None): | ||
""" | ||
Initialize the VariableLengthArray object. | ||
Args: | ||
data (Union[np.ndarray, list[list]]): The input data, which can be either a numpy array or a list of lists. | ||
lengths (np.ndarray): An array containing the lengths of each row in the data. | ||
Raises: | ||
ValueError: If the input data shape doesn't match the provided lengths. | ||
""" | ||
|
||
if lengths is not None: | ||
if len(data) != len(lengths): | ||
raise ValueError(f"The number of rows in the data array has to match the number of elements in lengths ({len(data)} != {len(lengths)})") | ||
|
||
if not isinstance(data, np.ndarray): | ||
for row, length in zip(data, lengths): | ||
if len(row) != length: | ||
raise ValueError(f"The length of row {row} does not match the specified length {length}") | ||
else: | ||
if not data.ndim >= 2: | ||
raise ValueError("If data is a numpy array, it has to be at least 2-dimensional") | ||
if np.max(lengths) > data.shape[1]: | ||
raise ValueError("The specified lengths are larger than the number of columns in the data array") | ||
|
||
else: | ||
if isinstance(data, np.ndarray): | ||
raise ValueError("If data is a numpy array, lengths must be provided") | ||
lengths = np.array([len(row) for row in data]) | ||
|
||
if isinstance(data, np.ndarray): | ||
self._data = data | ||
else: | ||
self._data = build_padded_2d_array(data, max_length=np.max(lengths)) | ||
|
||
# max_len = np.max(lengths) | ||
# self._data = np.full((len(data), max_len), np.nan) | ||
# for i, row in enumerate(data): | ||
# if len(row) < lengths[i]: | ||
# raise ValueError(f"Row {i} has fewer elements than specified in lengths ({len(row)} < {lengths[i]}") | ||
# self._data[i, :lengths[i]] = row[:lengths[i]] | ||
self.lengths = lengths | ||
|
||
def __len__(self): | ||
return len(self._data) | ||
|
||
def __getitem__(self, index): | ||
if isinstance(index, tuple): | ||
row_idx, col_idx = index | ||
if isinstance(row_idx, slice): | ||
if isinstance(col_idx, int): | ||
return np.array([self._data[i, :self.lengths[i]][col_idx] for i in range(*row_idx.indices(len(self._data)))]) | ||
elif isinstance(col_idx, slice): | ||
# does this work? | ||
return self._data[row_idx, :self.lengths[row_idx]][col_idx] | ||
else: | ||
return self._data[row_idx, :self.lengths[row_idx]][col_idx] | ||
elif isinstance(index, int): | ||
return self._data[index, :self.lengths[index]] | ||
else: | ||
return VariableLengthArray(self._data[index], self.lengths[index]) | ||
# new_lengths = self.lengths[index] | ||
# max_length = np.max(new_lengths) | ||
# new_data = self._data[index, :max_length] | ||
# return VariableLengthArray(new_data, new_lengths) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from pysaliency.utils import build_padded_2d_array | ||
from pysaliency.utils.variable_length_array import VariableLengthArray | ||
|
||
|
||
def test_variable_length_array_from_padded_array_basics(): | ||
# Test case 1 | ||
data = build_padded_2d_array([[1.0, 2, 3], [4, 5]]) | ||
lengths = np.array([3, 2]) | ||
array = VariableLengthArray(data, lengths) | ||
|
||
assert len(array) == 2 | ||
|
||
rows = list(array) | ||
assert np.array_equal(rows[0], np.array([1, 2, 3])) | ||
assert np.array_equal(rows[1], np.array([4, 5])) | ||
|
||
def test_variable_length_array_from_padded_array(): | ||
# Test case 1 | ||
data = build_padded_2d_array([[1.0, 2, 3], [4, 5]]) | ||
lengths = np.array([3, 2]) | ||
array = VariableLengthArray(data, lengths) | ||
|
||
# test accessing rows | ||
assert np.array_equal(array[0], np.array([1, 2, 3])) | ||
assert np.array_equal(array[1], np.array([4, 5])) | ||
|
||
# test accessing elements | ||
assert np.array_equal(array[0, 1], 2) | ||
|
||
# acessing elements outside the length of the row should raise an IndexError | ||
with pytest.raises(IndexError): | ||
array[1, 2] | ||
|
||
# test slicing | ||
assert np.array_equal(array[:, 0], [1, 4]) | ||
|
||
# test slicing with negative indices | ||
assert np.array_equal(array[:, -1], [3, 5]) | ||
|
||
|
||
|
||
# Test case 2 | ||
data = build_padded_2d_array([[1.0, 2], [3, 4, 5]]) | ||
lengths = np.array([2, 3]) | ||
array = VariableLengthArray(data, lengths) | ||
|
||
# test accessing rows | ||
assert np.array_equal(array[0], np.array([1, 2])) | ||
assert np.array_equal(array[1], np.array([3, 4, 5])) | ||
|
||
# test accessing elements | ||
assert np.array_equal(array[0, 1], 2) | ||
assert np.array_equal(array[1, 2], 5) | ||
|
||
# acessing elements outside the length of the row should raise an IndexError | ||
with pytest.raises(IndexError): | ||
array[1, 3] | ||
|
||
# test slicing | ||
assert np.array_equal(array[:, 0], [1, 3]) | ||
|
||
# test slicing with negative indices | ||
assert np.array_equal(array[:, -1], [2, 5]) | ||
|
||
|
||
def test_variable_length_array_slicing_with_slices(): | ||
data = build_padded_2d_array([[1.0, 2, 3], [4, 5], [6, 7, 8, 9]]) | ||
lengths = np.array([3, 2, 4]) | ||
array = VariableLengthArray(data, lengths) | ||
|
||
sub_array = array[1:] | ||
assert isinstance(sub_array, VariableLengthArray) | ||
assert len(sub_array) == 2 | ||
np.testing.assert_array_equal(sub_array._data, data[1:]) | ||
np.testing.assert_array_equal(sub_array[0], np.array([4, 5])) | ||
np.testing.assert_array_equal(sub_array[1], np.array([6, 7, 8, 9])) | ||
|
||
sub_array = array[:2] | ||
assert isinstance(sub_array, VariableLengthArray) | ||
assert len(sub_array) == 2 | ||
np.testing.assert_array_equal(sub_array._data, data[:2]) # one length item is cut off | ||
np.testing.assert_array_equal(sub_array[0], np.array([1, 2, 3])) | ||
np.testing.assert_array_equal(sub_array[1], np.array([4, 5])) | ||
|
||
|
||
def test_variable_length_array_slicing_with_indices(): | ||
data = build_padded_2d_array([[1.0, 2, 3], [4, 5], [6, 7, 8, 9]]) | ||
lengths = np.array([3, 2, 4]) | ||
array = VariableLengthArray(data, lengths) | ||
|
||
sub_array = array[[0, 2]] | ||
assert isinstance(sub_array, VariableLengthArray) | ||
assert len(sub_array) == 2 | ||
np.testing.assert_array_equal(sub_array._data, data[[0, 2]]) | ||
np.testing.assert_array_equal(sub_array[0], np.array([1, 2, 3])) | ||
np.testing.assert_array_equal(sub_array[1], np.array([6, 7, 8, 9])) | ||
|
||
|
||
def test_variable_length_array_slicing_with_mask(): | ||
data = build_padded_2d_array([[1.0, 2, 3], [4, 5], [6, 7, 8, 9]]) | ||
lengths = np.array([3, 2, 4]) | ||
array = VariableLengthArray(data, lengths) | ||
|
||
sub_array = array[[True, False, True]] | ||
assert isinstance(sub_array, VariableLengthArray) | ||
assert len(sub_array) == 2 | ||
np.testing.assert_array_equal(sub_array._data, data[[0, 2]]) | ||
np.testing.assert_array_equal(sub_array[0], np.array([1, 2, 3])) | ||
np.testing.assert_array_equal(sub_array[1], np.array([6, 7, 8, 9])) | ||
|
||
|
||
def test_variable_length_array_from_list_of_arrays(): | ||
# Test case 1 | ||
data = [[1, 2, 3], [4, 5]] | ||
lengths = np.array([3, 2]) | ||
array = VariableLengthArray(data, lengths) | ||
|
||
np.testing.assert_array_equal(array._data, np.array([[1, 2, 3], [4, 5, np.nan]])) | ||
|
||
|
||
def test_variable_length_array_from_list_of_arrays_without_specified_lengths(): | ||
data = [[1, 2, 3], [4, 5]] | ||
lengths = np.array([3, 2]) | ||
array = VariableLengthArray(data) | ||
|
||
np.testing.assert_array_equal(array._data, np.array([[1, 2, 3], [4, 5, np.nan]])) | ||
np.testing.assert_array_equal(array.lengths, lengths) | ||
|
||
|
||
def test_variable_length_array_inconsistent_lengths(): | ||
# consistent case | ||
data = [[1, 2, 3], [4]] | ||
lengths = np.array([3, 1]) | ||
|
||
VariableLengthArray(data, lengths) | ||
|
||
# inconsistent case | ||
data = [[1, 2, 3], [4]] | ||
lengths = np.array([3, 2]) | ||
|
||
with pytest.raises(ValueError): | ||
VariableLengthArray(data, lengths) |