From d813d3769466ba95d24efdf88f70accd4720069d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arne=20K=C3=BCderle?= Date: Wed, 8 Feb 2023 15:17:05 +0100 Subject: [PATCH] Bugfixes and tests for StrideEventZuptDetector --- gaitmap/zupt_detection/_base.py | 7 +- .../_stride_event_zupt_detector.py | 41 +++++---- .../test_stride_event_zupt_detector.py | 90 +++++++++++++++++++ 3 files changed, 116 insertions(+), 22 deletions(-) create mode 100644 tests/test_zupt_detection/test_stride_event_zupt_detector.py diff --git a/gaitmap/zupt_detection/_base.py b/gaitmap/zupt_detection/_base.py index 3733cf16..ab7b0248 100644 --- a/gaitmap/zupt_detection/_base.py +++ b/gaitmap/zupt_detection/_base.py @@ -1,7 +1,7 @@ import numpy as np import pandas as pd -from gaitmap.utils.array_handling import bool_array_to_start_end_array +from gaitmap.utils.array_handling import bool_array_to_start_end_array, start_end_array_to_bool_array from gaitmap.utils.datatype_helper import SingleSensorData @@ -39,7 +39,4 @@ class RegionZuptDetectorMixin: @property def per_sample_zupts_(self) -> np.ndarray: """Get a bool array of length data with all Zupts as True.""" - zupts = np.zeros(self.data.shape[0], dtype=bool) - for _, row in self.zupts_.iterrows(): - zupts[row["start"] : row["end"]] = True - return zupts + return start_end_array_to_bool_array(self.zupts_.to_numpy(), self.data.shape[0]) diff --git a/gaitmap/zupt_detection/_stride_event_zupt_detector.py b/gaitmap/zupt_detection/_stride_event_zupt_detector.py index 80d8a403..bc079144 100644 --- a/gaitmap/zupt_detection/_stride_event_zupt_detector.py +++ b/gaitmap/zupt_detection/_stride_event_zupt_detector.py @@ -5,6 +5,7 @@ from typing_extensions import Self from gaitmap.base import BaseZuptDetector +from gaitmap.utils.array_handling import merge_intervals from gaitmap.utils.datatype_helper import ( SingleSensorData, SingleSensorStrideList, @@ -44,18 +45,15 @@ class StrideEventZuptDetector(BaseZuptDetector, RegionZuptDetectorMixin): per_sample_zupts_ A bool 
array with length `len(data)`. If the value is `True` for a sample, it is part of a static region. - window_length_samples_ - The internally calculated window length in samples. - This might be helpful for debugging. - window_overlap_samples_ - The internally calculated window overlap in samples. - This might be helpful for debugging. + half_region_size_samples_ + The actual half region size in samples calculated using the data sampling rate. """ half_region_size_s: float + half_region_size_samples_: int - def __init__(self, half_region_size_s: float): + def __init__(self, half_region_size_s: float = 0.05): self.half_region_size_s = half_region_size_s def detect( @@ -87,25 +85,34 @@ def detect( self.data = data self.stride_event_list = stride_event_list self.sampling_rate_hz = sampling_rate_hz - is_single_sensor_data(self.data, check_acc=True, check_gyr=True, frame="any", raise_exception=True) + + if self.half_region_size_s < 0: + raise ValueError("The half region size must be >= 0") + + # We don't need the data. We still check it, as we need its length for the per_sample_zupts_ attribute. + # This means, we need to make at least sure that the data is somewhat valid. + is_single_sensor_data(self.data, check_acc=False, check_gyr=False, frame="any", raise_exception=True) try: - is_single_sensor_stride_list(self.stride_event_list, "min_vel", raise_exception=True) + is_single_sensor_stride_list( + self.stride_event_list, "min_vel", check_additional_cols=("min_vel",), raise_exception=True + ) except ValidationError as e: raise ValidationError( - "For the `StrideEventZuptDetector` a proper stride_event_list of the `min_vel` type" + "For the `StrideEventZuptDetector` a proper stride_event_list of the `min_vel` type is required." 
) from e - region_size_samples = int(np.round(self.half_region_size_s * sampling_rate_hz)) + self.half_region_size_samples_ = int(np.round(self.half_region_size_s * sampling_rate_hz)) # In a min_vel stride list, all starts and all ends are min_vel events. all_min_vel_events = np.unique(np.concatenate([self.stride_event_list["start"], self.stride_event_list["end"]])) - self.zupts_ = pd.DataFrame( - { - "start": np.clip(all_min_vel_events - region_size_samples, 0, None), - "end": np.clip(all_min_vel_events + region_size_samples, None, self.data.shape[0]), - } - ).astype(int) + start_ends = np.empty((len(all_min_vel_events), 2), dtype=int) + start_ends[:, 0] = np.clip(all_min_vel_events - self.half_region_size_samples_, 0, None) + start_ends[:, 1] = np.clip(all_min_vel_events + self.half_region_size_samples_ + 1, None, self.data.shape[0]) + self.zupts_ = pd.DataFrame(merge_intervals(start_ends), columns=["start", "end"]) + # This is required, because otherwise, edge cases at the start or end of the data could lead to zero-length + # ZUPTs. 
+ self.zupts_ = self.zupts_.loc[self.zupts_["start"] < self.zupts_["end"]] return self diff --git a/tests/test_zupt_detection/test_stride_event_zupt_detector.py b/tests/test_zupt_detection/test_stride_event_zupt_detector.py new file mode 100644 index 00000000..5e22753d --- /dev/null +++ b/tests/test_zupt_detection/test_stride_event_zupt_detector.py @@ -0,0 +1,90 @@ +import pandas as pd +import pytest +from pandas._testing import assert_frame_equal + +from gaitmap.utils.consts import SF_COLS +from gaitmap.utils.exceptions import ValidationError +from gaitmap.zupt_detection import StrideEventZuptDetector +from tests.mixins.test_algorithm_mixin import TestAlgorithmMixin + + +class TestMetaFunctionalityStrideEventZuptDetector(TestAlgorithmMixin): + __test__ = True + algorithm_class = StrideEventZuptDetector + + @pytest.fixture() + def after_action_instance(self, healthy_example_imu_data): + data_left = healthy_example_imu_data["left_sensor"].iloc[:10] + return StrideEventZuptDetector().detect( + data_left, + sampling_rate_hz=204.8, + stride_event_list=pd.DataFrame( + [[0, 5, 0]], columns=["start", "end", "min_vel"], index=pd.Series([0], name="s_id") + ), + ) + + +class TestStrideEventZuptDetector: + def test_improper_stride_list(self): + with pytest.raises(ValidationError): + StrideEventZuptDetector().detect( + pd.DataFrame([[0, 0, 0, 0, 0, 0]] * 10, columns=SF_COLS), + sampling_rate_hz=1, + # No min_vel column + stride_event_list=pd.DataFrame([[0, 5]], columns=["start", "end"]), + ) + + def test_region_0(self): + stride_event_list = pd.DataFrame( + [[0, 7, 0], [5, 10, 5]], columns=["start", "end", "min_vel"], index=pd.Series([0, 1], name="s_id") + ) + data = pd.DataFrame([[0, 0, 0, 0, 0, 0]] * 11, columns=SF_COLS) + + zupts = ( + StrideEventZuptDetector(half_region_size_s=0) + .detect(data=data, stride_event_list=stride_event_list, sampling_rate_hz=1) + .zupts_ + ) + + assert_frame_equal(zupts, pd.DataFrame([[0, 1], [5, 6], [7, 8], [10, 11]], columns=["start", 
"end"])) + + def test_edge_case(self): + """We test what happens if the zupt is exactly the first or last sample of the data or outside the range.""" + stride_event_list = pd.DataFrame( + [[0, 10, 0], [10, 15, 10]], columns=["start", "end", "min_vel"], index=pd.Series([0, 1], name="s_id") + ) + data = pd.DataFrame([[0, 0, 0, 0, 0, 0]] * 10, columns=SF_COLS) + + detector = StrideEventZuptDetector(half_region_size_s=0).detect( + data=data, stride_event_list=stride_event_list, sampling_rate_hz=1 + ) + zupts = detector.zupts_ + + assert_frame_equal(zupts, pd.DataFrame([[0, 1]], columns=["start", "end"])) + assert detector.half_region_size_samples_ == 0 + + def test_with_overlap(self): + stride_event_list = pd.DataFrame( + [[0, 7, 0], [5, 10, 5]], columns=["start", "end", "min_vel"], index=pd.Series([0, 1], name="s_id") + ) + data = pd.DataFrame([[0, 0, 0, 0, 0, 0]] * 11, columns=SF_COLS) + + detector = StrideEventZuptDetector(half_region_size_s=2).detect( + data=data, stride_event_list=stride_event_list, sampling_rate_hz=1 + ) + zupts = detector.zupts_ + assert_frame_equal(zupts, pd.DataFrame([[0, 11]], columns=["start", "end"])) + assert detector.half_region_size_samples_ == 2 + + def test_simple(self): + stride_event_list = pd.DataFrame( + [[0, 5, 0], [10, 15, 10]], columns=["start", "end", "min_vel"], index=pd.Series([0, 1], name="s_id") + ) + data = pd.DataFrame([[0, 0, 0, 0, 0, 0]] * 20, columns=SF_COLS) + + detector = StrideEventZuptDetector(half_region_size_s=0.5).detect( + data=data, stride_event_list=stride_event_list, sampling_rate_hz=2 + ) + zupts = detector.zupts_ + assert_frame_equal(zupts, pd.DataFrame([[0, 2], [4, 7], [9, 12], [14, 17]], columns=["start", "end"])) + assert detector.half_region_size_samples_ == 1