Skip to content

Commit

Permalink
Single threaded evaluation (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
joeranbosma authored Nov 9, 2022
1 parent 140b864 commit 8d24402
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 62 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def run(self):
long_description = fh.read()

setuptools.setup(
version='1.4.3', # also update version in metrics.py -> version
version='1.4.4', # also update version in metrics.py -> version
author_email='[email protected]',
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
133 changes: 84 additions & 49 deletions src/picai_eval/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def evaluate_case(
allow_unmatched_candidates_with_minimal_overlap: bool = True,
y_det_postprocess_func: "Optional[Callable[[npt.NDArray[np.float32]], npt.NDArray[np.float32]]]" = None,
y_true_postprocess_func: "Optional[Callable[[npt.NDArray[np.int32]], npt.NDArray[np.int32]]]" = None,
weight: Optional[float] = None,
idx: Optional[str] = None,
) -> Tuple[List[Tuple[int, float, float]], float]:
"""
Gather the list of lesion candidates, and classify in TP/FP/FN.
Expand Down Expand Up @@ -186,7 +188,47 @@ def evaluate_case(
# apply user-defined case-level confidence score function
case_confidence = case_confidence_func(y_det)

return y_list, case_confidence
return y_list, case_confidence, weight, idx


def make_evaluation_iterator(
    y_det: "Iterable[Union[npt.NDArray[np.float64], str, Path]]",
    y_true: "Iterable[Union[npt.NDArray[np.float64], str, Path]]",
    sample_weight: "Optional[Iterable[float]]" = None,
    subject_list: Optional[Iterable[Hashable]] = None,
    num_threads: int = 3,
    **kwargs
):
    """
    Build an iterator that runs `evaluate_case` on every case.

    Parameters:
    - y_det: iterable of detection maps (or paths to them), one per case.
    - y_true: iterable of annotations (or paths to them), one per case.
    - sample_weight: optional iterable with one weight per case. When omitted,
        `None` is passed as weight for every case (the default of `evaluate_case`).
    - subject_list: optional iterable with one identifier per case. When omitted,
        a running index (0, 1, 2, ...) is used so every case gets a distinct identifier.
    - num_threads: number of worker threads. Values of 2 or more evaluate the cases
        in a thread pool, anything lower evaluates them sequentially.
    - **kwargs: forwarded unchanged to `evaluate_case`.

    Returns:
    - iterator over the per-case results. For sequential evaluation each item is
        the value returned by `evaluate_case`. For multi-threaded evaluation each
        item is a `concurrent.futures.Future` whose `.result()` is that value.
    """
    # local import, so this function does not rely on a module-level import
    import itertools

    # `zip` and `map` cannot iterate over None: substitute neutral per-case defaults
    if sample_weight is None:
        sample_weight = itertools.repeat(None)
    if subject_list is None:
        subject_list = itertools.count()

    if num_threads >= 2:
        # process the cases in parallel
        with ThreadPoolExecutor(max_workers=num_threads) as pool:
            futures = {
                pool.submit(
                    evaluate_case,
                    y_det=y_det_case,
                    y_true=y_true_case,
                    weight=weight,
                    idx=idx,
                    **kwargs
                ): idx
                for (y_det_case, y_true_case, weight, idx) in zip(y_det, y_true, sample_weight, subject_list)
            }

        # leaving the `with` block waits for all submitted cases, so the futures
        # handed out below have already finished (or raised)
        iterator = concurrent.futures.as_completed(futures)
    else:
        # process the cases sequentially (lazily, one case per iteration step)
        def func(y_det_case, y_true_case, weight, idx):
            return evaluate_case(
                y_det=y_det_case,
                y_true=y_true_case,
                weight=weight,
                idx=idx,
                **kwargs
            )

        iterator = map(func, y_det, y_true, sample_weight, subject_list)

    return iterator


# Evaluate all cases
Expand Down Expand Up @@ -234,7 +276,7 @@ def evaluate(
lesion candidates from a softmax prediction volume.
- y_true_postprocess_func: function to apply to annotation. Can for example be used to select the lesion
masks from annotations that also contain other structures (such as organ segmentations).
- num_parallel_calls: number of threads to use for evaluation.
- num_parallel_calls: number of threads to use for evaluation. Set to 1 to disable parallelization.
- verbose: (optional) control the amount of printed information.
Returns:
Expand All @@ -253,54 +295,47 @@ def evaluate(
lesion_results: Dict[Hashable, List[Tuple[int, float, float]]] = {}
lesion_weight: Dict[Hashable, List[float]] = {}

with ThreadPoolExecutor(max_workers=num_parallel_calls) as pool:
# define the functions that need to be processed: compute_pred_vector, with each individual
# detection_map prediction, ground truth label and parameters
future_to_args = {
pool.submit(
evaluate_case,
y_det=y_det_case,
y_true=y_true_case,
min_overlap=min_overlap,
overlap_func=overlap_func,
case_confidence_func=case_confidence_func,
allow_unmatched_candidates_with_minimal_overlap=allow_unmatched_candidates_with_minimal_overlap,
y_det_postprocess_func=y_det_postprocess_func,
y_true_postprocess_func=y_true_postprocess_func
): (idx, weight)
for (y_det_case, y_true_case, weight, idx) in zip(y_det, y_true, sample_weight, subject_list)
}
iterator = make_evaluation_iterator(
y_det=y_det,
y_true=y_true,
sample_weight=sample_weight,
subject_list=subject_list,
num_threads=num_parallel_calls,
min_overlap=min_overlap,
overlap_func=overlap_func,
case_confidence_func=case_confidence_func,
allow_unmatched_candidates_with_minimal_overlap=allow_unmatched_candidates_with_minimal_overlap,
y_det_postprocess_func=y_det_postprocess_func,
y_true_postprocess_func=y_true_postprocess_func
)

# process the cases in parallel
iterator = concurrent.futures.as_completed(future_to_args)
if verbose:
total: Optional[int] = None
if isinstance(subject_list, Sized):
total = len(subject_list)
iterator = tqdm(iterator, desc='Evaluating', total=total)

for future in iterator:
idx, weight = future_to_args[future]

try:
# unpack results
lesion_results_case, case_confidence = future.result()
except Exception as e:
print(f'Error for {idx}: {e}')
raise e

# aggregate results
idx, weight = future_to_args[future]
case_weight[idx] = weight
case_pred[idx] = case_confidence
if len(lesion_results_case):
case_target[idx] = np.max([a[0] for a in lesion_results_case])
else:
case_target[idx] = 0

# accumulate outputs
lesion_results[idx] = lesion_results_case
lesion_weight[idx] = [weight] * len(lesion_results_case)
if verbose:
total: Optional[int] = None
if isinstance(subject_list, Sized):
total = len(subject_list)
iterator = tqdm(iterator, desc='Evaluating', total=total)

for result in iterator:
if isinstance(result, tuple):
# single-threaded evaluation
lesion_results_case, case_confidence, weight, idx = result
elif isinstance(result, concurrent.futures.Future):
# multi-threaded evaluation
lesion_results_case, case_confidence, weight, idx = result.result()
else:
raise TypeError(f'Unexpected result type: {type(result)}')

# aggregate results
case_weight[idx] = weight
case_pred[idx] = case_confidence
if len(lesion_results_case):
case_target[idx] = np.max([a[0] for a in lesion_results_case])
else:
case_target[idx] = 0

# accumulate outputs
lesion_results[idx] = lesion_results_case
lesion_weight[idx] = [weight] * len(lesion_results_case)

# collect results in a Metrics object
metrics = Metrics(
Expand Down
3 changes: 2 additions & 1 deletion tests/test_case_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
from numpy.testing import assert_allclose

from picai_eval.eval import evaluate_case


Expand Down Expand Up @@ -42,7 +43,7 @@ def test_evaluate_case(y_det, y_true, expected_y_list):
Evaluate predictions and calculate FROC statistics
"""

y_list, _ = evaluate_case(
y_list, *_ = evaluate_case(
y_det=os.path.join("tests/test-maps", y_det),
y_true=os.path.join("tests/test-maps", y_true),
)
Expand Down
37 changes: 26 additions & 11 deletions tests/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
import numpy as np
import pytest
from numpy.testing import assert_allclose
from sklearn.metrics import auc, roc_curve

from picai_eval import evaluate
from picai_eval.data_utils import load_metrics
from picai_eval.eval import Metrics, evaluate_folder
from picai_eval.image_utils import (read_label, read_prediction,
resize_image_with_crop_or_pad)
from sklearn.metrics import auc, roc_curve

subject_list = [
f"case-{i}"
Expand Down Expand Up @@ -59,7 +60,7 @@ def y_true():
yield y_true


def test_evaluation(y_det, y_true):
def test_evaluation(y_det, y_true, num_parallel_calls=3):
"""
Test standard evaluation pipeline
The 10 crafted cases in subject_list should have:
Expand All @@ -71,7 +72,8 @@ def test_evaluation(y_det, y_true):
metrics = evaluate(
y_det=y_det,
y_true=y_true,
subject_list=subject_list
subject_list=subject_list,
num_parallel_calls=num_parallel_calls,
)

# check metrics
Expand Down Expand Up @@ -184,7 +186,7 @@ def test_sample_weights():
@pytest.mark.xfail
def test_evaluation_negative_predictions(y_det, y_true):
"""
Test if evaluation works properly if (some) confidence scores are negative
Test if evaluation works if (some) confidence scores are negative
"""
# shift confidence scores to include both positive and negative confidences
shifted_y_det = [y_det - 1.5 for y_det in y_det]
Expand All @@ -197,25 +199,31 @@ def test_evaluation_negative_predictions(y_det, y_true):
evaluate(
y_det=shifted_y_det,
y_true=y_true,
subject_list=subject_list
subject_list=subject_list,
)


@pytest.mark.xfail
def test_softmax_input(y_det, y_true):
def test_softmax_input(y_det, y_true, num_parallel_calls=3):
"""
Test if evaluation throws an error when the input is a softmax volume (instead of detection maps)
"""
evaluate(
y_det=[np.random.normal(size=pred.shape) for pred in y_det],
y_true=y_true,
subject_list=subject_list
subject_list=subject_list,
num_parallel_calls=num_parallel_calls,
)


def test_evaluation_from_dir_with_subject_list():
def test_evaluation_from_dir_with_subject_list(num_parallel_calls=3):
detection_map_dir = "tests/test-maps"
metrics = evaluate_folder(detection_map_dir, subject_list=subject_list, verbose=1)
metrics = evaluate_folder(
detection_map_dir,
subject_list=subject_list,
num_parallel_calls=num_parallel_calls,
verbose=1,
)

# check metrics
assert metrics.lesion_TP[-2] == 5
Expand All @@ -224,9 +232,9 @@ def test_evaluation_from_dir_with_subject_list():
assert metrics.AP == (5/9)*(1/2) + 0*(1/2)


def test_evaluation_from_dir_without_subject_list():
def test_evaluation_from_dir_without_subject_list(num_parallel_calls=3):
detection_map_dir = "tests/test-maps"
metrics = evaluate_folder(detection_map_dir, verbose=1)
metrics = evaluate_folder(detection_map_dir, num_parallel_calls=num_parallel_calls, verbose=1)

# check metrics
assert metrics.lesion_TP[-2] == 5
Expand Down Expand Up @@ -324,3 +332,10 @@ def test_select_subset_lesion_results():
lesion_results_subset = metrics.get_lesion_results_flat(subject_list=["2", "3"])

assert lesion_results_subset_expected == lesion_results_subset


def test_single_threaded(y_det, y_true):
    """
    Re-run the in-memory and folder-based evaluation tests with a single thread
    """
    # every check below is executed with parallelization disabled
    single_thread = dict(num_parallel_calls=1)

    test_evaluation(y_det=y_det, y_true=y_true, **single_thread)
    test_evaluation_from_dir_with_subject_list(**single_thread)
    test_evaluation_from_dir_without_subject_list(**single_thread)

0 comments on commit 8d24402

Please sign in to comment.