Skip to content

Commit

Permalink
updated unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
richrobe committed Jan 8, 2025
1 parent 53185a0 commit 0145720
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ def extract(
elif self.handle_missing_events == "raise":
raise EventExtractionError(missing_str)

is_b_point_dataframe(b_points)
b_points = b_points.astype({"b_point_sample": "Int64", "nan_reason": "object"})
is_b_point_dataframe(b_points)

self.points_ = b_points
return self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,8 @@ def extract(

b_points.loc[idx, "b_point_sample"] = b_point_sample

is_b_point_dataframe(b_points)
b_points = b_points.astype({"b_point_sample": "Int64", "nan_reason": "object"})
is_b_point_dataframe(b_points)

self.points_ = b_points
return self
12 changes: 6 additions & 6 deletions tests/test_algorithms/test_b_point_extraction_arbol2017.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from pathlib import Path

import pandas as pd
import pytest

from biopsykit.signals.ecg.segmentation._heartbeat_segmentation import HeartbeatSegmentationNeurokit
from biopsykit.signals.icg.event_extraction import (
BPointExtractionArbol2017ThirdDerivative,
CPointExtractionScipyFindPeaks,
)
from biopsykit.utils._datatype_validation_helper import _assert_is_dtype
from biopsykit.utils.exceptions import ValidationError

TEST_FILE_PATH = Path(__file__).parent.joinpath("../test_data/pep")

Expand Down Expand Up @@ -70,12 +72,10 @@ def test_regression_extract_series(self):
icg_data = self.icg_data.squeeze()
_assert_is_dtype(icg_data, pd.Series)

reference_b_points = self._get_regression_reference()
self.extract_algo.extract(
icg=icg_data, heartbeats=self.heartbeats, c_points=self.c_points, sampling_rate_hz=self.sampling_rate_hz
)

self._check_b_point_equal(reference_b_points, self.extract_algo.points_)
with pytest.raises(ValidationError):
self.extract_algo.extract(
icg=icg_data, heartbeats=self.heartbeats, c_points=self.c_points, sampling_rate_hz=self.sampling_rate_hz
)

@staticmethod
def _get_regression_reference():
Expand Down
12 changes: 6 additions & 6 deletions tests/test_algorithms/test_b_point_extraction_debski1993.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from pathlib import Path

import pandas as pd
import pytest

from biopsykit.signals.ecg.segmentation._heartbeat_segmentation import HeartbeatSegmentationNeurokit
from biopsykit.signals.icg.event_extraction import (
BPointExtractionDebski1993SecondDerivative,
CPointExtractionScipyFindPeaks,
)
from biopsykit.utils._datatype_validation_helper import _assert_is_dtype
from biopsykit.utils.exceptions import ValidationError

TEST_FILE_PATH = Path(__file__).parent.joinpath("../test_data/pep")

Expand Down Expand Up @@ -70,12 +72,10 @@ def test_regression_extract_series(self):
icg_data = self.icg_data.squeeze()
_assert_is_dtype(icg_data, pd.Series)

reference_b_points = self._get_regression_reference()
self.extract_algo.extract(
icg=icg_data, heartbeats=self.heartbeats, c_points=self.c_points, sampling_rate_hz=self.sampling_rate_hz
)

self._check_b_point_equal(reference_b_points, self.extract_algo.points_)
with pytest.raises(ValidationError):
self.extract_algo.extract(
icg=icg_data, heartbeats=self.heartbeats, c_points=self.c_points, sampling_rate_hz=self.sampling_rate_hz
)

@staticmethod
def _get_regression_reference():
Expand Down
12 changes: 6 additions & 6 deletions tests/test_algorithms/test_b_point_extraction_drost2022.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from pathlib import Path

import pandas as pd
import pytest

from biopsykit.signals.ecg.segmentation._heartbeat_segmentation import HeartbeatSegmentationNeurokit
from biopsykit.signals.icg.event_extraction import (
BPointExtractionDrost2022,
CPointExtractionScipyFindPeaks,
)
from biopsykit.utils._datatype_validation_helper import _assert_is_dtype
from biopsykit.utils.exceptions import ValidationError

TEST_FILE_PATH = Path(__file__).parent.joinpath("../test_data/pep")

Expand Down Expand Up @@ -70,12 +72,10 @@ def test_regression_extract_series(self):
icg_data = self.icg_data.squeeze()
_assert_is_dtype(icg_data, pd.Series)

reference_b_points = self._get_regression_reference()
self.extract_algo.extract(
icg=icg_data, heartbeats=self.heartbeats, c_points=self.c_points, sampling_rate_hz=self.sampling_rate_hz
)

self._check_b_point_equal(reference_b_points, self.extract_algo.points_)
with pytest.raises(ValidationError):
self.extract_algo.extract(
icg=icg_data, heartbeats=self.heartbeats, c_points=self.c_points, sampling_rate_hz=self.sampling_rate_hz
)

@staticmethod
def _get_regression_reference():
Expand Down
12 changes: 6 additions & 6 deletions tests/test_algorithms/test_b_point_extraction_forouzanfar2018.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from pathlib import Path

import pandas as pd
import pytest

from biopsykit.signals.ecg.segmentation._heartbeat_segmentation import HeartbeatSegmentationNeurokit
from biopsykit.signals.icg.event_extraction import (
BPointExtractionForouzanfar2018,
CPointExtractionScipyFindPeaks,
)
from biopsykit.utils._datatype_validation_helper import _assert_is_dtype
from biopsykit.utils.exceptions import ValidationError

TEST_FILE_PATH = Path(__file__).parent.joinpath("../test_data/pep")

Expand Down Expand Up @@ -70,12 +72,10 @@ def test_regression_extract_series(self):
icg_data = self.icg_data.squeeze()
_assert_is_dtype(icg_data, pd.Series)

reference_b_points = self._get_regression_reference()
self.extract_algo.extract(
icg=icg_data, heartbeats=self.heartbeats, c_points=self.c_points, sampling_rate_hz=self.sampling_rate_hz
)

self._check_b_point_equal(reference_b_points, self.extract_algo.points_)
with pytest.raises(ValidationError):
self.extract_algo.extract(
icg=icg_data, heartbeats=self.heartbeats, c_points=self.c_points, sampling_rate_hz=self.sampling_rate_hz
)

@staticmethod
def _get_regression_reference():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from biopsykit.signals.ecg.segmentation._heartbeat_segmentation import HeartbeatSegmentationNeurokit
from biopsykit.signals.icg.event_extraction import CPointExtractionScipyFindPeaks
from biopsykit.utils._datatype_validation_helper import _assert_is_dtype
from biopsykit.utils.exceptions import ValidationError

TEST_FILE_PATH = Path(__file__).parent.joinpath("../test_data/pep")

Expand Down Expand Up @@ -58,10 +59,8 @@ def test_regression_extract_series(self):
icg_data = self.icg_data.squeeze()
_assert_is_dtype(icg_data, pd.Series)

reference_c_points = self._get_regression_reference()
self.extract_algo.extract(icg=icg_data, heartbeats=self.heartbeats, sampling_rate_hz=self.sampling_rate_hz)

self._check_c_point_equal(reference_c_points, self.extract_algo.points_)
with pytest.raises(ValidationError):
self.extract_algo.extract(icg=icg_data, heartbeats=self.heartbeats, sampling_rate_hz=self.sampling_rate_hz)

@staticmethod
def _get_regression_reference():
Expand Down
38 changes: 17 additions & 21 deletions tests/test_algorithms/test_heatbeat_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,11 @@ def test_regression_extract_variable_length_dataframe(self):
def test_regression_extract_variable_length_series(self):
self.setup()

reference_heartbeats = self._get_regression_reference("pep_test_heartbeat_reference_variable_length.csv")

ecg_data = self.ecg_data["ecg"]
_assert_is_dtype(ecg_data, pd.Series)

self.segmenter.extract(ecg=ecg_data, sampling_rate_hz=self.sampling_rate_hz)
# check if the first heartbeat is correct
self._check_heartbeats_equal(reference_heartbeats, self.segmenter.heartbeat_list_)
with pytest.raises(ValidationError):
self.segmenter.extract(ecg=ecg_data, sampling_rate_hz=self.sampling_rate_hz)

def test_regression_extract_fixed_length_dataframe(self):
self.setup()
Expand All @@ -93,34 +90,33 @@ def test_regression_extract_fixed_length_dataframe(self):
def test_regression_extract_fixed_length_series(self):
self.setup(variable_length=False)

reference_heartbeats = self._get_regression_reference("pep_test_heartbeat_reference_fixed_length.csv")

ecg_data = self.ecg_data["ecg"]
_assert_is_dtype(ecg_data, pd.Series)

self.segmenter.extract(ecg=ecg_data, sampling_rate_hz=self.sampling_rate_hz)
# check if the first heartbeat is correct
self._check_heartbeats_equal(reference_heartbeats, self.segmenter.heartbeat_list_)
with pytest.raises(ValidationError):
self.segmenter.extract(ecg=ecg_data, sampling_rate_hz=self.sampling_rate_hz)

def test_regression_extract_fixed_length_numpy(self):
self.setup(variable_length=False)

reference_heartbeats = self._get_regression_reference("pep_test_heartbeat_reference_fixed_length.csv")

ecg_data = self.ecg_data["ecg"].to_numpy()
_assert_is_dtype(ecg_data, np.ndarray)

self.segmenter.extract(ecg=ecg_data, sampling_rate_hz=self.sampling_rate_hz)
# check if the first heartbeat is correct
estimated_heartbeats = self.segmenter.heartbeat_list_.drop(columns="start_time")
reference_heartbeats = reference_heartbeats.drop(columns="start_time")

self._check_heartbeats_equal(reference_heartbeats, estimated_heartbeats)
with pytest.raises(ValidationError):
self.segmenter.extract(ecg=ecg_data, sampling_rate_hz=self.sampling_rate_hz)

@staticmethod
def _get_regression_reference(file_path):
data = pd.read_csv(TEST_FILE_PATH.joinpath(file_path), index_col=0, parse_dates=True)
data = data.convert_dtypes(infer_objects=True)
data = data.astype(
{
"start_sample": "Int64",
"end_sample": "Int64",
"r_peak_sample": "Int64",
"rr_interval_sample": "Int64",
"rr_interval_ms": "Float64",
}
)
data["start_time"] = pd.to_datetime(data["start_time"]).dt.tz_convert("Europe/Berlin")
return data

Expand All @@ -131,8 +127,8 @@ def _check_heartbeats_equal(reference_heartbeats, extracted_heartbeats):
@pytest.mark.parametrize(
("data", "expected"),
[
(None, pytest.raises(ValueError)),
(pd.Series([], dtype="Float64"), pytest.raises(ValueError)),
(None, pytest.raises(ValidationError)),
(pd.Series([], dtype="Float64"), pytest.raises(ValidationError)),
(pd.DataFrame(), pytest.raises(ValidationError)),
],
)
Expand Down
12 changes: 5 additions & 7 deletions tests/test_algorithms/test_q_peak_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from biopsykit.signals.ecg.event_extraction import QPeakExtractionMartinez2004Neurokit, QPeakExtractionVanLien2013
from biopsykit.signals.ecg.segmentation._heartbeat_segmentation import HeartbeatSegmentationNeurokit
from biopsykit.utils._datatype_validation_helper import _assert_is_dtype
from biopsykit.utils.exceptions import ValidationError

TEST_FILE_PATH = Path(__file__).parent.joinpath("../test_data/pep")

Expand Down Expand Up @@ -57,10 +58,8 @@ def test_regression_extract_series(self):
ecg_data = self.ecg_data.squeeze()
_assert_is_dtype(ecg_data, pd.Series)

reference_q_peaks = self._get_regression_reference()
self.extract_algo.extract(ecg=ecg_data, heartbeats=self.heartbeats, sampling_rate_hz=self.sampling_rate_hz)

self._check_q_peaks_equal(reference_q_peaks, self.extract_algo.points_)
with pytest.raises(ValidationError):
self.extract_algo.extract(ecg=ecg_data, heartbeats=self.heartbeats, sampling_rate_hz=self.sampling_rate_hz)

@staticmethod
def _get_regression_reference():
Expand Down Expand Up @@ -126,9 +125,8 @@ def test_regression_extract_series(self, time_interval_ms):
ecg_data = self.ecg_data.squeeze()
_assert_is_dtype(ecg_data, pd.Series)

reference_q_peaks = self._get_regression_reference(time_interval_ms)
self.extract_algo.extract(ecg=ecg_data, heartbeats=self.heartbeats, sampling_rate_hz=self.sampling_rate_hz)
self._check_q_peaks_equal(reference_q_peaks, self.extract_algo.points_)
with pytest.raises(ValidationError):
self.extract_algo.extract(ecg=ecg_data, heartbeats=self.heartbeats, sampling_rate_hz=self.sampling_rate_hz)

def _get_regression_reference(self, time_interval_ms: int = 40):
data = pd.read_csv(
Expand Down

0 comments on commit 0145720

Please sign in to comment.