Skip to content

Commit

Permalink
Bugfixes and tests for StrideEvenZuptDetector
Browse files Browse the repository at this point in the history
  • Loading branch information
AKuederle committed Feb 10, 2023
1 parent 1472c14 commit d813d37
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 22 deletions.
7 changes: 2 additions & 5 deletions gaitmap/zupt_detection/_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pandas as pd

from gaitmap.utils.array_handling import bool_array_to_start_end_array
from gaitmap.utils.array_handling import bool_array_to_start_end_array, start_end_array_to_bool_array
from gaitmap.utils.datatype_helper import SingleSensorData


Expand Down Expand Up @@ -39,7 +39,4 @@ class RegionZuptDetectorMixin:
@property
def per_sample_zupts_(self) -> np.ndarray:
"""Get a bool array of length data with all Zupts as True."""
zupts = np.zeros(self.data.shape[0], dtype=bool)
for _, row in self.zupts_.iterrows():
zupts[row["start"] : row["end"]] = True
return zupts
return start_end_array_to_bool_array(self.zupts_.to_numpy(), self.data.shape[0])
41 changes: 24 additions & 17 deletions gaitmap/zupt_detection/_stride_event_zupt_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing_extensions import Self

from gaitmap.base import BaseZuptDetector
from gaitmap.utils.array_handling import merge_intervals
from gaitmap.utils.datatype_helper import (
SingleSensorData,
SingleSensorStrideList,
Expand Down Expand Up @@ -44,18 +45,15 @@ class StrideEventZuptDetector(BaseZuptDetector, RegionZuptDetectorMixin):
per_sample_zupts_
A bool array with length `len(data)`.
If the value is `True` for a sample, it is part of a static region.
window_length_samples_
The internally calculated window length in samples.
This might be helpful for debugging.
window_overlap_samples_
The internally calculated window overlap in samples.
This might be helpful for debugging.
half_region_size_samples_
The actual half region size in samples calculated using the data sampling rate.
"""

half_region_size_s: float
half_region_size_samples_: int

def __init__(self, half_region_size_s: float):
def __init__(self, half_region_size_s: float = 0.05):
self.half_region_size_s = half_region_size_s

def detect(
Expand Down Expand Up @@ -87,25 +85,34 @@ def detect(
self.data = data
self.stride_event_list = stride_event_list
self.sampling_rate_hz = sampling_rate_hz
is_single_sensor_data(self.data, check_acc=True, check_gyr=True, frame="any", raise_exception=True)

if self.half_region_size_s < 0:
raise ValueError("The half region size must be >= 0")

# We don't need the data. We still check it, as we need its length for the per_sample_zupts_ attribute.
# This means, we need to make at least sure that the data is somewhat valid.
is_single_sensor_data(self.data, check_acc=False, check_gyr=False, frame="any", raise_exception=True)

try:
is_single_sensor_stride_list(self.stride_event_list, "min_vel", raise_exception=True)
is_single_sensor_stride_list(
self.stride_event_list, "min_vel", check_additional_cols=("min_vel",), raise_exception=True
)
except ValidationError as e:
raise ValidationError(
"For the `StrideEventZuptDetector` a proper stride_event_list of the `min_vel` type"
"For the `StrideEventZuptDetector` a proper stride_event_list of the `min_vel` type is required."
) from e

region_size_samples = int(np.round(self.half_region_size_s * sampling_rate_hz))
self.half_region_size_samples_ = int(np.round(self.half_region_size_s * sampling_rate_hz))

# In a min_vel stride list, all starts and all ends are min_vel events.
all_min_vel_events = np.unique(np.concatenate([self.stride_event_list["start"], self.stride_event_list["end"]]))

self.zupts_ = pd.DataFrame(
{
"start": np.clip(all_min_vel_events - region_size_samples, 0, None),
"end": np.clip(all_min_vel_events + region_size_samples, None, self.data.shape[0]),
}
).astype(int)
start_ends = np.empty((len(all_min_vel_events), 2), dtype=int)
start_ends[:, 0] = np.clip(all_min_vel_events - self.half_region_size_samples_, 0, None)
start_ends[:, 1] = np.clip(all_min_vel_events + self.half_region_size_samples_ + 1, None, self.data.shape[0])
self.zupts_ = pd.DataFrame(merge_intervals(start_ends), columns=["start", "end"])
# This is required, because otherwise, edge cases at the start or end of the data could lead to zero-length
# ZUPTs.
self.zupts_ = self.zupts_.loc[self.zupts_["start"] < self.zupts_["end"]]

return self
90 changes: 90 additions & 0 deletions tests/test_zupt_detection/test_stride_event_zupt_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import pandas as pd
import pytest
from pandas._testing import assert_frame_equal

from gaitmap.utils.consts import SF_COLS
from gaitmap.utils.exceptions import ValidationError
from gaitmap.zupt_detection import StrideEventZuptDetector
from tests.mixins.test_algorithm_mixin import TestAlgorithmMixin


class TestMetaFunctionalityStrideEventZuptDetector(TestAlgorithmMixin):
__test__ = True
algorithm_class = StrideEventZuptDetector

@pytest.fixture()
def after_action_instance(self, healthy_example_imu_data):
data_left = healthy_example_imu_data["left_sensor"].iloc[:10]
return StrideEventZuptDetector().detect(
data_left,
sampling_rate_hz=204.8,
stride_event_list=pd.DataFrame(
[[0, 5, 0]], columns=["start", "end", "min_vel"], index=pd.Series([0], name="s_id")
),
)


class TestStrideEventZuptDetector:
def test_improper_stride_list(self):
with pytest.raises(ValidationError):
StrideEventZuptDetector().detect(
pd.DataFrame([[0, 0, 0, 0, 0, 0]] * 10, columns=SF_COLS),
sampling_rate_hz=1,
# No min_vel column
stride_event_list=pd.DataFrame([[0, 5]], columns=["start", "end"]),
)

def test_region_0(self):
stride_event_list = pd.DataFrame(
[[0, 7, 0], [5, 10, 5]], columns=["start", "end", "min_vel"], index=pd.Series([0, 1], name="s_id")
)
data = pd.DataFrame([[0, 0, 0, 0, 0, 0]] * 11, columns=SF_COLS)

zupts = (
StrideEventZuptDetector(half_region_size_s=0)
.detect(data=data, stride_event_list=stride_event_list, sampling_rate_hz=1)
.zupts_
)

assert_frame_equal(zupts, pd.DataFrame([[0, 1], [5, 6], [7, 8], [10, 11]], columns=["start", "end"]))

def test_edge_case(self):
"""We test what happens if the zupt is exactly the first or last sample of the data or outside the range."""
stride_event_list = pd.DataFrame(
[[0, 10, 0], [10, 15, 10]], columns=["start", "end", "min_vel"], index=pd.Series([0, 1], name="s_id")
)
data = pd.DataFrame([[0, 0, 0, 0, 0, 0]] * 10, columns=SF_COLS)

detector = StrideEventZuptDetector(half_region_size_s=0).detect(
data=data, stride_event_list=stride_event_list, sampling_rate_hz=1
)
zupts = detector.zupts_

assert_frame_equal(zupts, pd.DataFrame([[0, 1]], columns=["start", "end"]))
assert detector.half_region_size_samples_ == 0

def test_with_overlap(self):
stride_event_list = pd.DataFrame(
[[0, 7, 0], [5, 10, 5]], columns=["start", "end", "min_vel"], index=pd.Series([0, 1], name="s_id")
)
data = pd.DataFrame([[0, 0, 0, 0, 0, 0]] * 11, columns=SF_COLS)

detector = StrideEventZuptDetector(half_region_size_s=2).detect(
data=data, stride_event_list=stride_event_list, sampling_rate_hz=1
)
zupts = detector.zupts_
assert_frame_equal(zupts, pd.DataFrame([[0, 11]], columns=["start", "end"]))
assert detector.half_region_size_samples_ == 2

def test_simple(self):
stride_event_list = pd.DataFrame(
[[0, 5, 0], [10, 15, 10]], columns=["start", "end", "min_vel"], index=pd.Series([0, 1], name="s_id")
)
data = pd.DataFrame([[0, 0, 0, 0, 0, 0]] * 20, columns=SF_COLS)

detector = StrideEventZuptDetector(half_region_size_s=0.5).detect(
data=data, stride_event_list=stride_event_list, sampling_rate_hz=2
)
zupts = detector.zupts_
assert_frame_equal(zupts, pd.DataFrame([[0, 2], [4, 7], [9, 12], [14, 17]], columns=["start", "end"]))
assert detector.half_region_size_samples_ == 1

0 comments on commit d813d37

Please sign in to comment.