From c02294fedc87da7474894ba9fda0e8da3ed5aecc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arne=20K=C3=BCderle?= Date: Wed, 8 Feb 2023 15:49:10 +0100 Subject: [PATCH] Further adapted start_end_array_to_bool_array to fit Zupt detector usecase --- CHANGELOG.md | 4 ++++ gaitmap/utils/array_handling.py | 15 ++++++++------- gaitmap/zupt_detection/_combo_zupt_detector.py | 3 ++- tests/test_utils/test_array_handling.py | 8 +++++++- 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index eb0bab27..feba68df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -95,6 +95,10 @@ project. - The util method `start_end_array_to_bool_array` now assumes that the end index of all regions is inclusive. This enables roundtrip conversion with the `bool_array_to_start_end_array` method and is in line with the definitions used for strides, ROIs, and ZUPTs in gaitmap. + Further, the method now supports to output arrays that are shorter than the largest input index. + Before, this resulted in an error. + Both changes might require some user facing code changes, if this function is used. + However, as it was not used internally, it is likely that no one was using it anyway. (https://github.com/mad-lab-fau/gaitmap/pull/14) diff --git a/gaitmap/utils/array_handling.py b/gaitmap/utils/array_handling.py index e0c31f31..d9faa073 100644 --- a/gaitmap/utils/array_handling.py +++ b/gaitmap/utils/array_handling.py @@ -145,8 +145,9 @@ def start_end_array_to_bool_array(start_end_array: np.ndarray, pad_to_length: in This is in line with the definitions of stride and roi lists in gaitmap. pad_to_length: int - define length of resulting array if None is given the array will have the length of the last element of the - initial start_end_array + Define the length of the resulting array. + If None, the array will have the length of the largest index. + Otherwise, the final array will either be padded with False or truncated to the specified length. Returns ------- @@ -167,11 +168,11 @@ def start_end_array_to_bool_array(start_end_array: np.ndarray, pad_to_length: in """ start_end_array = np.atleast_2d(start_end_array) - n_elements = start_end_array.max() - - if pad_to_length: - if pad_to_length <= n_elements: - raise ValueError("Padding length must be larger than last element of start end array!") + if pad_to_length is None: + n_elements = start_end_array.max() + else: + if pad_to_length < 0: + raise ValueError("pad_to_length must be positive!") n_elements = pad_to_length bool_array = np.zeros(n_elements) diff --git a/gaitmap/zupt_detection/_combo_zupt_detector.py b/gaitmap/zupt_detection/_combo_zupt_detector.py index 50a6c087..2ccf19ee 100644 --- a/gaitmap/zupt_detection/_combo_zupt_detector.py +++ b/gaitmap/zupt_detection/_combo_zupt_detector.py @@ -66,7 +66,8 @@ def detect( Returns ------- self - The fitted instance + The class instance with all result attributes populated + """ if not self.detectors: raise ValueError("No detectors have been set.") diff --git a/tests/test_utils/test_array_handling.py b/tests/test_utils/test_array_handling.py index 75168e88..69268054 100644 --- a/tests/test_utils/test_array_handling.py +++ b/tests/test_utils/test_array_handling.py @@ -289,7 +289,13 @@ def test_invalid_padding(self): with pytest.raises(ValueError) as e: start_end_array_to_bool_array(input_array, pad_to_length=-1) - assert "Padding length must be larger than" in str(e) + assert "pad_to_length must be positive" in str(e) + + def test_short_padding(self): + input_array = np.array([[2, 3], [5, 9]]) + output_array = start_end_array_to_bool_array(input_array, pad_to_length=7) + expected_output = np.array([0, 0, 1, 0, 0, 1, 1]).astype(bool) + assert_array_equal(expected_output, output_array) def test_correct_output_dtype(self): input_array = np.array([[2, 3], [5, 9]])