From 2e9802b5779374de1e5d06f4e48e04744bcca85b Mon Sep 17 00:00:00 2001 From: Mark Eastwood <20169086+measty@users.noreply.github.com> Date: Fri, 13 Oct 2023 13:04:34 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9D=87=EF=B8=8F=20=20Add=20Convert=20`patche?= =?UTF-8?q?s`=20Output=20to=20`AnnotationStore`=20(#718)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Adds a function `patch_pred_store` to convert the output from `PatchPredictor` into an `AnnotationStore`. --------- Co-authored-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tests/test_utils.py | 69 ++++++++++++++++++++++++++++++++++++++++ tiatoolbox/utils/misc.py | 68 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 135 insertions(+), 2 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 4ba721569..0e3cdf93c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -18,6 +18,7 @@ from tests.test_annotation_stores import cell_polygon from tiatoolbox import utils +from tiatoolbox.annotation.storage import SQLiteStore from tiatoolbox.models.architecture import fetch_pretrained_weights from tiatoolbox.utils import misc from tiatoolbox.utils.exceptions import FileNotSupportedError @@ -734,6 +735,7 @@ def test_sub_pixel_read_incorrect_read_func_return() -> None: image = np.ones((10, 10)) def read_func(*args: tuple, **kwargs: dict) -> np.ndarray: # noqa: ARG001 + """Dummy read function for tests.""" return np.ones((5, 5)) with pytest.raises(ValueError, match="incorrect size"): @@ -752,6 +754,7 @@ def test_sub_pixel_read_empty_read_func_return() -> None: image = np.ones((10, 10)) def read_func(*args: tuple, **kwargs: dict) -> np.ndarray: # noqa: ARG001 + """Dummy read function for tests.""" return np.ones((0, 0)) with pytest.raises(ValueError, match="is empty"): @@ -1642,3 +1645,69 @@ def test_imwrite(tmp_path: Path) -> NoReturn: tmp_path / "thisfolderdoesnotexist" / "test_imwrite.jpg", img, ) + + +def test_patch_pred_store() -> None: + """Test patch_pred_store.""" + # Define a mock patch_output + patch_output = { + "predictions": [1, 0, 1], + "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)], + "other": "other", + } + + store = misc.patch_pred_store(patch_output, (1.0, 1.0)) + + # Check that its an SQLiteStore containing the expected annotations + assert isinstance(store, SQLiteStore) + assert len(store) == 3 + for annotation in store.values(): + assert annotation.geometry.area == 1 + assert annotation.properties["type"] in [0, 1] + assert "other" not in annotation.properties + + patch_output.pop("coordinates") + # check correct error is raised if coordinates are missing + with pytest.raises(ValueError, match="coordinates"): + misc.patch_pred_store(patch_output, (1.0, 1.0)) + + +def test_patch_pred_store_cdict() -> None: + """Test patch_pred_store with a class dict.""" + # Define a mock patch_output + patch_output = { + "predictions": [1, 0, 1], + "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)], + "probabilities": [[0.1, 0.9], [0.9, 0.1], [0.4, 0.6]], + "labels": [1, 0, 1], + "other": "other", + } + class_dict = {0: "class0", 1: "class1"} + store = misc.patch_pred_store(patch_output, (1.0, 1.0), class_dict=class_dict) + + # Check that its an SQLiteStore containing the expected annotations + assert isinstance(store, SQLiteStore) + assert len(store) == 3 + for annotation in store.values(): + assert annotation.geometry.area == 1 + assert annotation.properties["label"] in ["class0", "class1"] + assert annotation.properties["type"] in ["class0", "class1"] + assert "other" not in annotation.properties + + +def test_patch_pred_store_sf() -> None: + """Test patch_pred_store with scale factor.""" + # Define a mock patch_output + patch_output = { + "predictions": [1, 0, 1], + "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)], + "probabilities": [[0.1, 0.9], [0.9, 0.1], [0.4, 0.6]], + "labels": [1, 0, 1], + } + store = misc.patch_pred_store(patch_output, (2.0, 2.0)) + + # Check that its an SQLiteStore containing the expected annotations + assert isinstance(store, SQLiteStore) + assert len(store) == 3 + for annotation in store.values(): + assert annotation.geometry.area == 4 diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 89e60970e..ebe6c198e 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -18,6 +18,7 @@ import yaml from filelock import FileLock from shapely.affinity import translate +from shapely.geometry import Polygon from shapely.geometry import shape as feature2geometry from skimage import exposure @@ -860,7 +861,8 @@ def select_device(*, on_gpu: bool) -> str: """Selects the appropriate device as requested. Args: - on_gpu (bool): Selects gpu if True. + on_gpu (bool): + Selects gpu if True. Returns: str: @@ -883,7 +885,6 @@ def model_to(model: torch.nn.Module, *, on_gpu: bool) -> torch.nn.Module: Returns: torch.nn.Module: The model after being moved to cpu/gpu. - """ if on_gpu: # DataParallel work only for cuda model = torch.nn.DataParallel(model) @@ -1194,3 +1195,66 @@ def add_from_dat( logger.info("Added %d annotations.", len(anns)) store.append_many(anns) + + +def patch_pred_store( + patch_output: dict, + scale_factor: tuple[int, int], + class_dict: dict | None = None, +) -> AnnotationStore: + """Create an SQLiteStore containing Annotations for each patch. + + Args: + patch_output (dict): A dictionary of patch prediction information. Important + keys are "probabilities", "predictions", "coordinates", and "labels". + scale_factor (tuple[int, int]): The scale factor to use when loading the + annotations. All coordinates will be multiplied by this factor to allow + conversion of annotations saved at non-baseline resolution to baseline. + Should be model_mpp/slide_mpp. + class_dict (dict): Optional dictionary mapping class indices to class names. + + Returns: + SQLiteStore: An SQLiteStore containing Annotations for each patch. + + """ + if "coordinates" not in patch_output: + # we cant create annotations without coordinates + msg = "Patch output must contain coordinates." + raise ValueError(msg) + # get relevant keys + class_probs = patch_output.get("probabilities", []) + preds = patch_output.get("predictions", []) + patch_coords = np.array(patch_output.get("coordinates", [])) + if not np.all(np.array(scale_factor) == 1): + patch_coords = patch_coords * (np.tile(scale_factor, 2)) # to baseline mpp + labels = patch_output.get("labels", []) + # get classes to consider + if len(class_probs) == 0: + classes_predicted = np.unique(preds).tolist() + else: + classes_predicted = range(len(class_probs[0])) + if class_dict is None: + # if no class dict create a default one + class_dict = {i: i for i in np.unique(preds + labels).tolist()} + annotations = [] + # find what keys we need to save + keys = ["predictions"] + keys = keys + [key for key in ["probabilities", "labels"] if key in patch_output] + + # put patch predictions into a store + annotations = [] + for i, pred in enumerate(preds): + if "probabilities" in keys: + props = { + f"prob_{class_dict[j]}": class_probs[i][j] for j in classes_predicted + } + else: + props = {} + if "labels" in keys: + props["label"] = class_dict[labels[i]] + props["type"] = class_dict[pred] + annotations.append(Annotation(Polygon.from_bounds(*patch_coords[i]), props)) + store = SQLiteStore() + keys = store.append_many(annotations, [str(i) for i in range(len(annotations))]) + + return store