Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…toolbox into add-bokeh-app
  • Loading branch information
measty committed Oct 13, 2023
2 parents 96c1bd9 + 2e9802b commit 6869433
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 2 deletions.
69 changes: 69 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from tests.test_annotation_stores import cell_polygon
from tiatoolbox import utils
from tiatoolbox.annotation.storage import SQLiteStore
from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.utils import misc
from tiatoolbox.utils.exceptions import FileNotSupportedError
Expand Down Expand Up @@ -734,6 +735,7 @@ def test_sub_pixel_read_incorrect_read_func_return() -> None:
image = np.ones((10, 10))

def read_func(*args: tuple, **kwargs: dict) -> np.ndarray: # noqa: ARG001
"""Dummy read function for tests."""
return np.ones((5, 5))

with pytest.raises(ValueError, match="incorrect size"):
Expand All @@ -752,6 +754,7 @@ def test_sub_pixel_read_empty_read_func_return() -> None:
image = np.ones((10, 10))

def read_func(*args: tuple, **kwargs: dict) -> np.ndarray: # noqa: ARG001
"""Dummy read function for tests."""
return np.ones((0, 0))

with pytest.raises(ValueError, match="is empty"):
Expand Down Expand Up @@ -1642,3 +1645,69 @@ def test_imwrite(tmp_path: Path) -> NoReturn:
tmp_path / "thisfolderdoesnotexist" / "test_imwrite.jpg",
img,
)


def test_patch_pred_store() -> None:
"""Test patch_pred_store."""
# Define a mock patch_output
patch_output = {
"predictions": [1, 0, 1],
"coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)],
"other": "other",
}

store = misc.patch_pred_store(patch_output, (1.0, 1.0))

# Check that its an SQLiteStore containing the expected annotations
assert isinstance(store, SQLiteStore)
assert len(store) == 3
for annotation in store.values():
assert annotation.geometry.area == 1
assert annotation.properties["type"] in [0, 1]
assert "other" not in annotation.properties

patch_output.pop("coordinates")
# check correct error is raised if coordinates are missing
with pytest.raises(ValueError, match="coordinates"):
misc.patch_pred_store(patch_output, (1.0, 1.0))


def test_patch_pred_store_cdict() -> None:
"""Test patch_pred_store with a class dict."""
# Define a mock patch_output
patch_output = {
"predictions": [1, 0, 1],
"coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)],
"probabilities": [[0.1, 0.9], [0.9, 0.1], [0.4, 0.6]],
"labels": [1, 0, 1],
"other": "other",
}
class_dict = {0: "class0", 1: "class1"}
store = misc.patch_pred_store(patch_output, (1.0, 1.0), class_dict=class_dict)

# Check that its an SQLiteStore containing the expected annotations
assert isinstance(store, SQLiteStore)
assert len(store) == 3
for annotation in store.values():
assert annotation.geometry.area == 1
assert annotation.properties["label"] in ["class0", "class1"]
assert annotation.properties["type"] in ["class0", "class1"]
assert "other" not in annotation.properties


def test_patch_pred_store_sf() -> None:
"""Test patch_pred_store with scale factor."""
# Define a mock patch_output
patch_output = {
"predictions": [1, 0, 1],
"coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)],
"probabilities": [[0.1, 0.9], [0.9, 0.1], [0.4, 0.6]],
"labels": [1, 0, 1],
}
store = misc.patch_pred_store(patch_output, (2.0, 2.0))

# Check that its an SQLiteStore containing the expected annotations
assert isinstance(store, SQLiteStore)
assert len(store) == 3
for annotation in store.values():
assert annotation.geometry.area == 4
68 changes: 66 additions & 2 deletions tiatoolbox/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import yaml
from filelock import FileLock
from shapely.affinity import translate
from shapely.geometry import Polygon
from shapely.geometry import shape as feature2geometry
from skimage import exposure

Expand Down Expand Up @@ -860,7 +861,8 @@ def select_device(*, on_gpu: bool) -> str:
"""Selects the appropriate device as requested.
Args:
on_gpu (bool): Selects gpu if True.
on_gpu (bool):
Selects gpu if True.
Returns:
str:
Expand All @@ -883,7 +885,6 @@ def model_to(model: torch.nn.Module, *, on_gpu: bool) -> torch.nn.Module:
Returns:
torch.nn.Module:
The model after being moved to cpu/gpu.
"""
if on_gpu: # DataParallel work only for cuda
model = torch.nn.DataParallel(model)
Expand Down Expand Up @@ -1194,3 +1195,66 @@ def add_from_dat(

logger.info("Added %d annotations.", len(anns))
store.append_many(anns)


def patch_pred_store(
patch_output: dict,
scale_factor: tuple[int, int],
class_dict: dict | None = None,
) -> AnnotationStore:
"""Create an SQLiteStore containing Annotations for each patch.
Args:
patch_output (dict): A dictionary of patch prediction information. Important
keys are "probabilities", "predictions", "coordinates", and "labels".
scale_factor (tuple[int, int]): The scale factor to use when loading the
annotations. All coordinates will be multiplied by this factor to allow
conversion of annotations saved at non-baseline resolution to baseline.
Should be model_mpp/slide_mpp.
class_dict (dict): Optional dictionary mapping class indices to class names.
Returns:
SQLiteStore: An SQLiteStore containing Annotations for each patch.
"""
if "coordinates" not in patch_output:
# we cant create annotations without coordinates
msg = "Patch output must contain coordinates."
raise ValueError(msg)
# get relevant keys
class_probs = patch_output.get("probabilities", [])
preds = patch_output.get("predictions", [])
patch_coords = np.array(patch_output.get("coordinates", []))
if not np.all(np.array(scale_factor) == 1):
patch_coords = patch_coords * (np.tile(scale_factor, 2)) # to baseline mpp
labels = patch_output.get("labels", [])
# get classes to consider
if len(class_probs) == 0:
classes_predicted = np.unique(preds).tolist()
else:
classes_predicted = range(len(class_probs[0]))
if class_dict is None:
# if no class dict create a default one
class_dict = {i: i for i in np.unique(preds + labels).tolist()}
annotations = []
# find what keys we need to save
keys = ["predictions"]
keys = keys + [key for key in ["probabilities", "labels"] if key in patch_output]

# put patch predictions into a store
annotations = []
for i, pred in enumerate(preds):
if "probabilities" in keys:
props = {
f"prob_{class_dict[j]}": class_probs[i][j] for j in classes_predicted
}
else:
props = {}
if "labels" in keys:
props["label"] = class_dict[labels[i]]
props["type"] = class_dict[pred]
annotations.append(Annotation(Polygon.from_bounds(*patch_coords[i]), props))
store = SQLiteStore()
keys = store.append_many(annotations, [str(i) for i in range(len(annotations))])

return store

0 comments on commit 6869433

Please sign in to comment.