From 62e9c100673a0f0a15c279c22c5c0a75fb0d51c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ludvig=20Bergenstr=C3=A5hle?= Date: Sun, 1 Nov 2020 22:43:31 +0100 Subject: [PATCH] improve masking --- xfuse/convert/image.py | 3 +- xfuse/convert/st.py | 2 +- xfuse/convert/visium.py | 2 +- xfuse/utility/core.py | 53 +++++++++++++++++++++++++++++++++- xfuse/utility/mask.py | 25 +++++++++------- xfuse/utility/visualization.py | 3 +- 6 files changed, 72 insertions(+), 16 deletions(-) diff --git a/xfuse/convert/image.py b/xfuse/convert/image.py index c1f8b24c..2f9cb6e0 100644 --- a/xfuse/convert/image.py +++ b/xfuse/convert/image.py @@ -4,8 +4,9 @@ import pandas as pd from PIL import Image +from ..utility.core import rescale from ..utility.mask import compute_tissue_mask -from .utility import rescale, write_data +from .utility import write_data def run( diff --git a/xfuse/convert/st.py b/xfuse/convert/st.py index 8eab7cde..e3b40ac8 100644 --- a/xfuse/convert/st.py +++ b/xfuse/convert/st.py @@ -5,11 +5,11 @@ import pandas as pd from PIL import Image +from ..utility.core import rescale from .utility import ( Spot, labels_from_spots, mask_tissue, - rescale, write_data, ) diff --git a/xfuse/convert/visium.py b/xfuse/convert/visium.py index 7a660ae6..2be53b57 100644 --- a/xfuse/convert/visium.py +++ b/xfuse/convert/visium.py @@ -6,11 +6,11 @@ from PIL import Image from scipy.sparse import csr_matrix +from ..utility.core import rescale from .utility import ( Spot, labels_from_spots, mask_tissue, - rescale, write_data, ) diff --git a/xfuse/utility/core.py b/xfuse/utility/core.py index b2bfbadd..ec0d10f0 100644 --- a/xfuse/utility/core.py +++ b/xfuse/utility/core.py @@ -1,10 +1,23 @@ -from typing import Any, ContextManager, Protocol, Tuple, TypeVar, Union +from typing import ( + Any, + ContextManager, + Protocol, + Tuple, + TypeVar, + Sequence, + Union, +) import warnings +import numpy as np +from PIL import Image + __all__ = [ "center_crop", + "rescale", + "resize", "temp_attr", ] @@ -47,6 +60,44 @@ def center_crop(x: ArrayType, target_shape: Tuple[int, ...]) -> ArrayType: ] +def rescale( + image: np.ndarray, scaling_factor: float, resample: int = Image.NEAREST +) -> np.ndarray: + r""" + Rescales image by a given `scaling_factor` + + :param image: Image array + :param scaling_factor: Scaling factor + :param resample: Resampling filter + :returns: The rescaled image + """ + image = Image.fromarray(image) + image = image.resize( + [round(x * scaling_factor) for x in image.size], resample=resample, + ) + image = np.array(image) + return image + + +def resize( + image: np.ndarray, + target_shape: Sequence[int], + resample: int = Image.NEAREST, +) -> np.ndarray: + r""" + Resizes image to a given `target_shape` + + :param image: Image array + :param target_shape: Target shape + :param resample: Resampling filter + :returns: The rescaled image + """ + image = Image.fromarray(image) + image = image.resize(target_shape[::-1], resample=resample) + image = np.array(image) + return image + + def temp_attr(obj: object, attr: str, value: Any) -> ContextManager: r""" Creates a context manager for setting transient object attributes. diff --git a/xfuse/utility/mask.py b/xfuse/utility/mask.py index a4f5e5e6..018289cc 100644 --- a/xfuse/utility/mask.py +++ b/xfuse/utility/mask.py @@ -3,9 +3,11 @@ import cv2 as cv import numpy as np +from PIL import Image from scipy.ndimage import label +from scipy.ndimage.morphology import binary_fill_holes -from .core import center_crop +from .core import rescale, resize from ..logging import INFO, log @@ -22,7 +24,6 @@ def remove_fg_elements(mask: np.ndarray, size_threshold: float): def compute_tissue_mask( image: np.ndarray, - initial_mask: Optional[np.ndarray] = None, convergence_threshold: float = 0.0001, size_threshold: float = 0.01, ) -> np.ndarray: @@ -32,16 +33,16 @@ def compute_tissue_mask( """ # pylint: disable=no-member # ^ pylint fails to identify cv.* members - if initial_mask is None: - initial_mask = np.zeros(image.shape[:2], dtype=np.bool) - initial_mask_center = center_crop( - initial_mask, - tuple(int(round(x * 0.8)) for x in iter(initial_mask.shape)), - ) - initial_mask_center[...] = True + original_shape = image.shape[:2] + scale_factor = 1000 / max(original_shape) + + image = rescale(image, scale_factor, resample=Image.NEAREST) + initial_mask = binary_fill_holes( + cv.blur(cv.Canny(cv.blur(image, (5, 5)), 100, 200), (5, 5)) + ) - mask = cv.GC_PR_BGD * np.ones(image.shape[:2], dtype=np.uint8) - mask[initial_mask] = cv.GC_PR_FGD + mask = np.where(initial_mask, cv.GC_PR_FGD, cv.GC_PR_BGD) + mask = mask.astype(np.uint8) bgd_model = np.zeros((1, 65), np.float64) fgd_model = bgd_model.copy() @@ -61,6 +62,8 @@ def compute_tissue_mask( mask = mask == cv.GC_PR_FGD mask = cleanup_mask(mask, size_threshold) + mask = resize(mask, target_shape=original_shape, resample=Image.NEAREST) + return mask diff --git a/xfuse/utility/visualization.py b/xfuse/utility/visualization.py index 544bb569..3a10484a 100644 --- a/xfuse/utility/visualization.py +++ b/xfuse/utility/visualization.py @@ -13,7 +13,8 @@ from ..data.slide import FullSlide, Slide from ..data.utility.misc import make_dataloader from ..session import Session, get, require -from ..utility.mask import center_crop, cleanup_mask +from ..utility.core import center_crop +from ..utility.mask import cleanup_mask __all__ = ["reduce_last_dimension", "visualize_metagenes"]