From 3dbe8dcef93127d513d5dfef348caee88def2930 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20K=C3=BCmmerer?= Date: Thu, 28 Mar 2024 00:16:01 +0100 Subject: [PATCH] Feature: VariableLengthArray for scanpath coordinates, histories etc MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Matthias Kümmerer --- .gitignore | 1 + pysaliency/{utils.py => utils/__init__.py} | 40 +++--- pysaliency/utils/variable_length_array.py | 97 ++++++++++++++ tests/utils/test_variable_length_array.py | 145 +++++++++++++++++++++ 4 files changed, 265 insertions(+), 18 deletions(-) rename pysaliency/{utils.py => utils/__init__.py} (96%) create mode 100644 pysaliency/utils/variable_length_array.py create mode 100644 tests/utils/test_variable_length_array.py diff --git a/.gitignore b/.gitignore index 220fd40..a951049 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ *.c *.so *.egg-info +.vscode diff --git a/pysaliency/utils.py b/pysaliency/utils/__init__.py similarity index 96% rename from pysaliency/utils.py rename to pysaliency/utils/__init__.py index 407d4e3..88f1e59 100644 --- a/pysaliency/utils.py +++ b/pysaliency/utils/__init__.py @@ -1,35 +1,39 @@ -from __future__ import print_function, absolute_import, division -from collections.abc import Sequence, MutableMapping -from itertools import chain -from glob import iglob -from contextlib import contextmanager, ExitStack -import warnings as _warnings -import os as _os -import sys as _sys -import os +from __future__ import absolute_import, division, print_function + import hashlib -from functools import partial -import warnings +import os +import os as _os import shutil -from itertools import count, filterfalse, groupby import subprocess as sp +import sys as _sys +import warnings +import warnings as _warnings +from collections.abc import MutableMapping, Sequence +from contextlib import ExitStack, contextmanager +from functools import partial +from glob import iglob +from itertools import chain, count, filterfalse, groupby from tempfile import mkdtemp +import deprecation import numpy as np -from scipy.interpolate import griddata - +import requests from boltons.cacheutils import LRU -import deprecation +from scipy.interpolate import griddata from tqdm import tqdm -import requests def build_padded_2d_array(arrays, max_length=None, padding_value=np.nan): if max_length is None: max_length = np.max([len(a) for a in arrays]) - output = np.ones((len(arrays), max_length), dtype=np.asarray(arrays[0]).dtype) - output *= padding_value + #output = np.ones((len(arrays), max_length), dtype=np.asarray(arrays[0]).dtype) + dtype = np.asarray(arrays[0]).dtype + + if np.issubdtype(dtype, np.integer) and padding_value is np.nan: + dtype = np.float64 + + output = np.full((len(arrays), max_length), fill_value=padding_value, dtype=dtype) for i, array in enumerate(arrays): output[i, :len(array)] = array diff --git a/pysaliency/utils/variable_length_array.py b/pysaliency/utils/variable_length_array.py new file mode 100644 index 0000000..1e01fa2 --- /dev/null +++ b/pysaliency/utils/variable_length_array.py @@ -0,0 +1,97 @@ +from typing import Optional, Union + +import numpy as np + +from . import build_padded_2d_array + + +class VariableLengthArray: + """ + Represents a variable length array. + + The following indexing operations are supported: + - Accessing rows: array[i] + - Accessing elements: array[i, j] where j can also be negative to get elements from the end of the row + - Slicing: array[i:j, k] where k can also be negative to get elements from the end of each row + + + Args: + data (Union[np.ndarray, list[list]]): The data for the array. Can be either a numpy array or a list of lists. + lengths (np.ndarray): The lengths of each row in the data array. + + Attributes: + _data (np.ndarray): The internal data array with padded rows. + lengths (np.ndarray): The lengths of each row in the data array. + + Methods: + __len__(): Returns the number of rows in the array. + __getitem__(index): Returns the value(s) at the specified index(es) in the array. + """ + + def __init__(self, data: Union[np.ndarray, list[list]], lengths: Optional[np.ndarray] = None): + """ + Initialize the VariableLengthArray object. + + Args: + data (Union[np.ndarray, list[list]]): The input data, which can be either a numpy array or a list of lists. + lengths (np.ndarray): An array containing the lengths of each row in the data. + + Raises: + ValueError: If the input data shape doesn't match the provided lengths. + + """ + + if lengths is not None: + if len(data) != len(lengths): + raise ValueError(f"The number of rows in the data array has to match the number of elements in lengths ({len(data)} != {len(lengths)})") + + if not isinstance(data, np.ndarray): + for row, length in zip(data, lengths): + if len(row) != length: + raise ValueError(f"The length of row {row} does not match the specified length {length}") + else: + if not data.ndim >= 2: + raise ValueError("If data is a numpy array, it has to be at least 2-dimensional") + if np.max(lengths) > data.shape[1]: + raise ValueError("The specified lengths are larger than the number of columns in the data array") + + else: + if isinstance(data, np.ndarray): + raise ValueError("If data is a numpy array, lengths must be provided") + lengths = np.array([len(row) for row in data]) + + if isinstance(data, np.ndarray): + self._data = data + else: + self._data = build_padded_2d_array(data, max_length=np.max(lengths)) + + # max_len = np.max(lengths) + # self._data = np.full((len(data), max_len), np.nan) + # for i, row in enumerate(data): + # if len(row) < lengths[i]: + # raise ValueError(f"Row {i} has fewer elements than specified in lengths ({len(row)} < {lengths[i]}") + # self._data[i, :lengths[i]] = row[:lengths[i]] + self.lengths = lengths + + def __len__(self): + return len(self._data) + + def __getitem__(self, index): + if isinstance(index, tuple): + row_idx, col_idx = index + if isinstance(row_idx, slice): + if isinstance(col_idx, int): + return np.array([self._data[i, :self.lengths[i]][col_idx] for i in range(*row_idx.indices(len(self._data)))]) + elif isinstance(col_idx, slice): + # does this work? + return self._data[row_idx, :self.lengths[row_idx]][col_idx] + else: + return self._data[row_idx, :self.lengths[row_idx]][col_idx] + elif isinstance(index, int): + return self._data[index, :self.lengths[index]] + else: + return VariableLengthArray(self._data[index], self.lengths[index]) + # new_lengths = self.lengths[index] + # max_length = np.max(new_lengths) + # new_data = self._data[index, :max_length] + # return VariableLengthArray(new_data, new_lengths) \ No newline at end of file diff --git a/tests/utils/test_variable_length_array.py b/tests/utils/test_variable_length_array.py new file mode 100644 index 0000000..cc03b63 --- /dev/null +++ b/tests/utils/test_variable_length_array.py @@ -0,0 +1,145 @@ +import numpy as np +import pytest + +from pysaliency.utils import build_padded_2d_array +from pysaliency.utils.variable_length_array import VariableLengthArray + + +def test_variable_length_array_from_padded_array_basics(): + # Test case 1 + data = build_padded_2d_array([[1.0, 2, 3], [4, 5]]) + lengths = np.array([3, 2]) + array = VariableLengthArray(data, lengths) + + assert len(array) == 2 + + rows = list(array) + assert np.array_equal(rows[0], np.array([1, 2, 3])) + assert np.array_equal(rows[1], np.array([4, 5])) + +def test_variable_length_array_from_padded_array(): + # Test case 1 + data = build_padded_2d_array([[1.0, 2, 3], [4, 5]]) + lengths = np.array([3, 2]) + array = VariableLengthArray(data, lengths) + + # test accessing rows + assert np.array_equal(array[0], np.array([1, 2, 3])) + assert np.array_equal(array[1], np.array([4, 5])) + + # test accessing elements + assert np.array_equal(array[0, 1], 2) + + # acessing elements outside the length of the row should raise an IndexError + with pytest.raises(IndexError): + array[1, 2] + + # test slicing + assert np.array_equal(array[:, 0], [1, 4]) + + # test slicing with negative indices + assert np.array_equal(array[:, -1], [3, 5]) + + + + # Test case 2 + data = build_padded_2d_array([[1.0, 2], [3, 4, 5]]) + lengths = np.array([2, 3]) + array = VariableLengthArray(data, lengths) + + # test accessing rows + assert np.array_equal(array[0], np.array([1, 2])) + assert np.array_equal(array[1], np.array([3, 4, 5])) + + # test accessing elements + assert np.array_equal(array[0, 1], 2) + assert np.array_equal(array[1, 2], 5) + + # acessing elements outside the length of the row should raise an IndexError + with pytest.raises(IndexError): + array[1, 3] + + # test slicing + assert np.array_equal(array[:, 0], [1, 3]) + + # test slicing with negative indices + assert np.array_equal(array[:, -1], [2, 5]) + + +def test_variable_length_array_slicing_with_slices(): + data = build_padded_2d_array([[1.0, 2, 3], [4, 5], [6, 7, 8, 9]]) + lengths = np.array([3, 2, 4]) + array = VariableLengthArray(data, lengths) + + sub_array = array[1:] + assert isinstance(sub_array, VariableLengthArray) + assert len(sub_array) == 2 + np.testing.assert_array_equal(sub_array._data, data[1:]) + np.testing.assert_array_equal(sub_array[0], np.array([4, 5])) + np.testing.assert_array_equal(sub_array[1], np.array([6, 7, 8, 9])) + + sub_array = array[:2] + assert isinstance(sub_array, VariableLengthArray) + assert len(sub_array) == 2 + np.testing.assert_array_equal(sub_array._data, data[:2]) # one length item is cut off + np.testing.assert_array_equal(sub_array[0], np.array([1, 2, 3])) + np.testing.assert_array_equal(sub_array[1], np.array([4, 5])) + + +def test_variable_length_array_slicing_with_indices(): + data = build_padded_2d_array([[1.0, 2, 3], [4, 5], [6, 7, 8, 9]]) + lengths = np.array([3, 2, 4]) + array = VariableLengthArray(data, lengths) + + sub_array = array[[0, 2]] + assert isinstance(sub_array, VariableLengthArray) + assert len(sub_array) == 2 + np.testing.assert_array_equal(sub_array._data, data[[0, 2]]) + np.testing.assert_array_equal(sub_array[0], np.array([1, 2, 3])) + np.testing.assert_array_equal(sub_array[1], np.array([6, 7, 8, 9])) + + +def test_variable_length_array_slicing_with_mask(): + data = build_padded_2d_array([[1.0, 2, 3], [4, 5], [6, 7, 8, 9]]) + lengths = np.array([3, 2, 4]) + array = VariableLengthArray(data, lengths) + + sub_array = array[[True, False, True]] + assert isinstance(sub_array, VariableLengthArray) + assert len(sub_array) == 2 + np.testing.assert_array_equal(sub_array._data, data[[0, 2]]) + np.testing.assert_array_equal(sub_array[0], np.array([1, 2, 3])) + np.testing.assert_array_equal(sub_array[1], np.array([6, 7, 8, 9])) + + +def test_variable_length_array_from_list_of_arrays(): + # Test case 1 + data = [[1, 2, 3], [4, 5]] + lengths = np.array([3, 2]) + array = VariableLengthArray(data, lengths) + + np.testing.assert_array_equal(array._data, np.array([[1, 2, 3], [4, 5, np.nan]])) + + +def test_variable_length_array_from_list_of_arrays_without_specified_lengths(): + data = [[1, 2, 3], [4, 5]] + lengths = np.array([3, 2]) + array = VariableLengthArray(data) + + np.testing.assert_array_equal(array._data, np.array([[1, 2, 3], [4, 5, np.nan]])) + np.testing.assert_array_equal(array.lengths, lengths) + + +def test_variable_length_array_inconsistent_lengths(): + # consistent case + data = [[1, 2, 3], [4]] + lengths = np.array([3, 1]) + + VariableLengthArray(data, lengths) + + # inconsistent case + data = [[1, 2, 3], [4]] + lengths = np.array([3, 2]) + + with pytest.raises(ValueError): + VariableLengthArray(data, lengths) \ No newline at end of file