Skip to content

Commit

Permalink
PEP extraction algorithms: updated tests to work with new implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
richrobe committed Jan 3, 2025
1 parent a4214a4 commit 641ee88
Show file tree
Hide file tree
Showing 11 changed files with 32 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def _get_most_prominent_monotonic_increasing_segment(icg_segment: pd.Series, hei
].index

end_index_drop_rule_b = end_index_drop_rule_b.union(end_index_drop_rule_b - 1)
monotony_df = monotony_df.drop(index=end_index_drop_rule_b)
monotony_df = monotony_df.drop(index=monotony_df.iloc[end_index_drop_rule_b].index)

# Select the monotonic segment with the highest amplitude difference
start_sample = 0
Expand Down
2 changes: 1 addition & 1 deletion tests/test_algorithms/test_b_point_extraction_arbol2017.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_regression_extract_series(self):
@staticmethod
def _get_regression_reference():
data = pd.read_csv(TEST_FILE_PATH.joinpath("pep_test_b_point_reference_arbol2017.csv"), index_col=0)
data = data.convert_dtypes(infer_objects=True)
data = data.astype({"b_point_sample": "Int64", "nan_reason": "object"})
return data

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_regression_extract_series(self):
@staticmethod
def _get_regression_reference():
data = pd.read_csv(TEST_FILE_PATH.joinpath("pep_test_b_point_reference_debski1993.csv"), index_col=0)
data = data.convert_dtypes(infer_objects=True)
data = data.astype({"b_point_sample": "Int64", "nan_reason": "object"})
return data

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion tests/test_algorithms/test_b_point_extraction_drost2022.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_regression_extract_series(self):
@staticmethod
def _get_regression_reference():
data = pd.read_csv(TEST_FILE_PATH.joinpath("pep_test_b_point_reference_drost2022.csv"), index_col=0)
data = data.convert_dtypes(infer_objects=True)
data = data.astype({"b_point_sample": "Int64", "nan_reason": "object"})
return data

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_regression_extract_series(self):
@staticmethod
def _get_regression_reference():
data = pd.read_csv(TEST_FILE_PATH.joinpath("pep_test_b_point_reference_forouzanfar2018.csv"), index_col=0)
data = data.convert_dtypes(infer_objects=True)
data = data.astype({"b_point_sample": "Int64", "nan_reason": "object"})
return data

@staticmethod
Expand Down
21 changes: 2 additions & 19 deletions tests/test_algorithms/test_c_point_extraction_scipy_findpeaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_regression_extract_series(self):
@staticmethod
def _get_regression_reference():
data = pd.read_csv(TEST_FILE_PATH.joinpath("pep_test_c_point_reference_scipy_findpeaks.csv"), index_col=0)
data = data.convert_dtypes(infer_objects=True)
data = data.astype({"c_point_sample": "Int64", "nan_reason": "object"})
return data

@staticmethod
Expand All @@ -78,7 +78,6 @@ class TestCPointExtractionSciPyFindpeaksParameters:
def setup(
self,
window_c_correction: Optional[int] = 3,
save_candidates: Optional[bool] = False,
):
# Sample ECG data
self.ecg_data = pd.read_csv(TEST_FILE_PATH.joinpath("pep_test_ecg.csv"), index_col=0)
Expand All @@ -88,9 +87,7 @@ def setup(
self.heartbeats = self.segmenter.extract(
ecg=self.ecg_data, sampling_rate_hz=self.sampling_rate_hz
).heartbeat_list_
self.extract_algo = CPointExtractionScipyFindPeaks(
window_c_correction=window_c_correction, save_candidates=save_candidates
)
self.extract_algo = CPointExtractionScipyFindPeaks(window_c_correction=window_c_correction)
self.test_case = unittest.TestCase()

@pytest.mark.parametrize(
Expand All @@ -102,20 +99,6 @@ def test_extract_window_c_correction(self, window_c_correction):

self.extract_algo.extract(icg=self.icg_data, heartbeats=self.heartbeats, sampling_rate_hz=self.sampling_rate_hz)

print(self.extract_algo.points_)

assert isinstance(self.extract_algo.points_, pd.DataFrame)
assert "c_point_sample" in self.extract_algo.points_.columns
assert "nan_reason" in self.extract_algo.points_.columns

@pytest.mark.parametrize(
("save_candidates", "expected_columns"),
[(True, ["c_point_sample", "nan_reason", "c_point_candidates"]), (False, ["c_point_sample", "nan_reason"])],
)
def test_extract_window_(self, save_candidates, expected_columns):
self.setup(save_candidates=save_candidates)

self.extract_algo.extract(icg=self.icg_data, heartbeats=self.heartbeats, sampling_rate_hz=self.sampling_rate_hz)

assert isinstance(self.extract_algo.points_, pd.DataFrame)
self.test_case.assertListEqual(expected_columns, self.extract_algo.points_.columns.tolist())
6 changes: 1 addition & 5 deletions tests/test_algorithms/test_heatbeat_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,6 @@ def test_regression_extract_variable_length_dataframe(self):
_assert_is_dtype(ecg_data, pd.DataFrame)

self.segmenter.extract(ecg=ecg_data, sampling_rate_hz=self.sampling_rate_hz)

# print(self.segmenter.heartbeat_list_["start_time"].dtype)
# print(reference_heartbeats["start_time"].dtype)

# check that the extracted heartbeats equal the reference heartbeats
self._check_heartbeats_equal(reference_heartbeats, self.segmenter.heartbeat_list_)

Expand Down Expand Up @@ -136,7 +132,7 @@ def _check_heartbeats_equal(reference_heartbeats, extracted_heartbeats):
("data", "expected"),
[
(None, pytest.raises(ValueError)),
(pd.Series([]), pytest.raises(ValueError)),
(pd.Series([], dtype="Float64"), pytest.raises(ValueError)),
(pd.DataFrame(), pytest.raises(ValidationError)),
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def _get_b_point_outlier_middle(self):

def _get_regression_reference(self):
data = pd.read_csv(TEST_FILE_PATH.joinpath("pep_test_icg_outlier_correction_forouzanfar2018.csv"), index_col=0)
data = data.convert_dtypes(infer_objects=True)
data = data.astype({"b_point_sample": "Int64", "nan_reason": "object"})
return data

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ def test_regression_correct_outlier(self, outlier_type):
b_points=b_points, c_points=self.c_points, sampling_rate_hz=self.sampling_rate_hz
)

print(self.outlier_algo.points_)

corrected_beats = (self.b_points - self.outlier_algo.points_)["b_point_sample"] != 0
corrected_beats = self.b_points.index[corrected_beats]

Expand All @@ -87,7 +85,7 @@ def _get_b_point_outlier_middle(self):

def _get_regression_reference(self):
data = pd.read_csv(TEST_FILE_PATH.joinpath("pep_test_icg_outlier_correction_interpolation.csv"), index_col=0)
data = data.convert_dtypes(infer_objects=True)
data = data.astype({"b_point_sample": "Int64", "nan_reason": "object"})
return data

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from contextlib import contextmanager
from pathlib import Path

import numpy as np
import pandas as pd
import pytest

Expand Down Expand Up @@ -35,44 +36,44 @@ def test_extract(self):
self.extract_algo.extract(ecg=self.ecg_data, heartbeats=self.heartbeats, sampling_rate_hz=self.sampling_rate_hz)

assert isinstance(self.extract_algo.points_, pd.DataFrame)
assert "q_wave_onset_sample" in self.extract_algo.points_.columns
assert "q_peak_sample" in self.extract_algo.points_.columns
assert "nan_reason" in self.extract_algo.points_.columns

# add regression test to check if the extracted q-wave onsets match with the saved reference
# regression test: check that the extracted q-peaks match the saved reference
def test_regression_extract_dataframe(self):
self.setup()

ecg_data = self.ecg_data
_assert_is_dtype(ecg_data, pd.DataFrame)

reference_q_wave_onsets = self._get_regression_reference()
reference_q_peaks = self._get_regression_reference()
self.extract_algo.extract(ecg=ecg_data, heartbeats=self.heartbeats, sampling_rate_hz=self.sampling_rate_hz)

self._check_q_wave_onset_equal(reference_q_wave_onsets, self.extract_algo.points_)
self._check_q_peaks_equal(reference_q_peaks, self.extract_algo.points_)

def test_regression_extract_series(self):
self.setup()

ecg_data = self.ecg_data.squeeze()
_assert_is_dtype(ecg_data, pd.Series)

reference_q_wave_onsets = self._get_regression_reference()
reference_q_peaks = self._get_regression_reference()
self.extract_algo.extract(ecg=ecg_data, heartbeats=self.heartbeats, sampling_rate_hz=self.sampling_rate_hz)

self._check_q_wave_onset_equal(reference_q_wave_onsets, self.extract_algo.points_)
self._check_q_peaks_equal(reference_q_peaks, self.extract_algo.points_)

@staticmethod
def _get_regression_reference():
data = pd.read_csv(TEST_FILE_PATH.joinpath("pep_test_q_wave_onset_reference_neurokit_dwt.csv"), index_col=0)
data = data.convert_dtypes(infer_objects=True)
data = pd.read_csv(TEST_FILE_PATH.joinpath("pep_test_q_peak_reference_neurokit_dwt.csv"), index_col=0)
data = data.astype({"q_peak_sample": "Int64", "nan_reason": "object"})
return data

@staticmethod
def _check_q_wave_onset_equal(reference_heartbeats, extracted_heartbeats):
def _check_q_peaks_equal(reference_heartbeats, extracted_heartbeats):
pd.testing.assert_frame_equal(reference_heartbeats, extracted_heartbeats)


class TestQWaveOnsetExtractionVanLien2013:
class TestQPeakExtractionVanLien2013:
def setup(self, time_interval_ms: int = 40):
# Sample ECG data
self.ecg_data = pd.read_csv(TEST_FILE_PATH.joinpath("pep_test_ecg.csv"), index_col=0)
Expand All @@ -98,9 +99,9 @@ def test_extract(self):
self.extract_algo.extract(ecg=self.ecg_data, heartbeats=self.heartbeats, sampling_rate_hz=self.sampling_rate_hz)

assert isinstance(self.extract_algo.points_, pd.DataFrame)
assert "q_wave_onset_sample" in self.extract_algo.points_.columns
assert "q_peak_sample" in self.extract_algo.points_.columns

# add regression test to check if the extracted q-wave onsets match with the saved reference
# regression test: check that the extracted q-peaks match the saved reference
@pytest.mark.parametrize(
("time_interval_ms"),
[34, 36, 38, 40],
Expand All @@ -111,9 +112,9 @@ def test_regression_extract_dataframe(self, time_interval_ms):
ecg_data = self.ecg_data
_assert_is_dtype(ecg_data, pd.DataFrame)

reference_q_wave_onsets = self._get_regression_reference(time_interval_ms)
reference_q_peaks = self._get_regression_reference(time_interval_ms)
self.extract_algo.extract(ecg=ecg_data, heartbeats=self.heartbeats, sampling_rate_hz=self.sampling_rate_hz)
self._check_q_wave_onset_equal(reference_q_wave_onsets, self.extract_algo.points_)
self._check_q_peaks_equal(reference_q_peaks, self.extract_algo.points_)

@pytest.mark.parametrize(
("time_interval_ms"),
Expand All @@ -125,20 +126,20 @@ def test_regression_extract_series(self, time_interval_ms):
ecg_data = self.ecg_data.squeeze()
_assert_is_dtype(ecg_data, pd.Series)

reference_q_wave_onsets = self._get_regression_reference(time_interval_ms)
reference_q_peaks = self._get_regression_reference(time_interval_ms)
self.extract_algo.extract(ecg=ecg_data, heartbeats=self.heartbeats, sampling_rate_hz=self.sampling_rate_hz)
self._check_q_wave_onset_equal(reference_q_wave_onsets, self.extract_algo.points_)
self._check_q_peaks_equal(reference_q_peaks, self.extract_algo.points_)

def _get_regression_reference(self, time_interval_ms: int = 40):
data = pd.read_csv(
TEST_FILE_PATH.joinpath("pep_test_heartbeat_reference_variable_length.csv"), index_col=0, parse_dates=True
)
data = data.convert_dtypes(infer_objects=True)
data = data[["r_peak_sample"]] - int((time_interval_ms / self.sampling_rate_hz) * 1000)
data.columns = ["q_wave_onset_sample"]

data = data.assign(nan_reason=np.NAN)
data.columns = ["q_peak_sample", "nan_reason"]
data = data.astype({"q_peak_sample": "Int64", "nan_reason": "object"})
return data

@staticmethod
def _check_q_wave_onset_equal(reference_heartbeats, extracted_heartbeats):
def _check_q_peaks_equal(reference_heartbeats, extracted_heartbeats):
pd.testing.assert_frame_equal(reference_heartbeats, extracted_heartbeats)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
heartbeat_id,q_wave_onset_sample,nan_reason
heartbeat_id,q_peak_sample,nan_reason
0,423,
1,1012,
2,1614,
Expand Down

0 comments on commit 641ee88

Please sign in to comment.