From 4f3ab87bcc63e72f07733fcd4a7795a2951545bc Mon Sep 17 00:00:00 2001 From: Joeran Bosma Date: Thu, 3 Oct 2024 17:13:00 +0200 Subject: [PATCH] Account for new keyword in sklearn v1.5+ (#19) --- setup.py | 2 +- src/picai_eval/metrics.py | 20 +++++++++++++++----- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index 04a3bdb..c9b347f 100644 --- a/setup.py +++ b/setup.py @@ -63,7 +63,7 @@ def run(self): long_description = fh.read() setuptools.setup( - version='1.4.6', # also update version in metrics.py -> version + version='1.4.7', # also update version in metrics.py -> version author_email='Joeran.Bosma@radboudumc.nl', long_description=long_description, long_description_content_type="text/markdown", diff --git a/src/picai_eval/metrics.py b/src/picai_eval/metrics.py index 2d99e00..0ec9bbe 100644 --- a/src/picai_eval/metrics.py +++ b/src/picai_eval/metrics.py @@ -19,6 +19,8 @@ from typing import Any, Dict, Hashable, List, Optional, Tuple, Union import numpy as np +import sklearn +from packaging import version from sklearn.metrics import auc, precision_recall_curve, roc_curve try: @@ -265,11 +267,19 @@ def calculate_precision_recall(self, subject_list: Optional[List[str]] = None) - y_pred: "npt.NDArray[np.float64]" = np.array([pred for _, pred, *_ in lesion_y_list]) # calculate precision-recall curve - precision, recall, thresholds = precision_recall_curve( - y_true=y_true, - probas_pred=y_pred, - sample_weight=self.get_lesion_weight_flat(subject_list=subject_list) - ) + if version.parse(sklearn.__version__) >= version.parse("1.5"): + # in the future this if/else block can be removed, then set 1.5 as minimum in requirements.txt + precision, recall, thresholds = precision_recall_curve( + y_true=y_true, + y_score=y_pred, + sample_weight=self.get_lesion_weight_flat(subject_list=subject_list) + ) + else: + precision, recall, thresholds = precision_recall_curve( + y_true=y_true, + probas_pred=y_pred, + sample_weight=self.get_lesion_weight_flat(subject_list=subject_list) + ) # set precision to zero at a threshold of "zero", as those lesion # candidates are included just to convey the number of lesions to