Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve masking #9

Merged
merged 3 commits into from
Nov 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion xfuse/convert/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import pandas as pd
from PIL import Image

from ..utility.core import rescale
from ..utility.mask import compute_tissue_mask
from .utility import rescale, write_data
from .utility import write_data


def run(
Expand Down
2 changes: 1 addition & 1 deletion xfuse/convert/st.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import pandas as pd
from PIL import Image

from ..utility.core import rescale
from .utility import (
Spot,
labels_from_spots,
mask_tissue,
rescale,
write_data,
)

Expand Down
2 changes: 1 addition & 1 deletion xfuse/convert/visium.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from PIL import Image
from scipy.sparse import csr_matrix

from ..utility.core import rescale
from .utility import (
Spot,
labels_from_spots,
mask_tissue,
rescale,
write_data,
)

Expand Down
53 changes: 52 additions & 1 deletion xfuse/utility/core.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
from typing import Any, ContextManager, Protocol, Tuple, TypeVar, Union
from typing import (
Any,
ContextManager,
Protocol,
Tuple,
TypeVar,
Sequence,
Union,
)

import warnings

import numpy as np
from PIL import Image


__all__ = [
"center_crop",
"rescale",
"resize",
"temp_attr",
]

Expand Down Expand Up @@ -47,6 +60,44 @@ def center_crop(x: ArrayType, target_shape: Tuple[int, ...]) -> ArrayType:
]


def rescale(
image: np.ndarray, scaling_factor: float, resample: int = Image.NEAREST
) -> np.ndarray:
r"""
Rescales image by a given `scaling_factor`

:param image: Image array
:param scaling_factor: Scaling factor
:param resample: Resampling filter
:returns: The rescaled image
"""
image = Image.fromarray(image)
image = image.resize(
[round(x * scaling_factor) for x in image.size], resample=resample,
)
image = np.array(image)
return image


def resize(
image: np.ndarray,
target_shape: Sequence[int],
resample: int = Image.NEAREST,
) -> np.ndarray:
r"""
Resizes image to a given `target_shape`

:param image: Image array
:param target_shape: Target shape
:param resample: Resampling filter
:returns: The rescaled image
"""
image = Image.fromarray(image)
image = image.resize(target_shape[::-1], resample=resample)
image = np.array(image)
return image


def temp_attr(obj: object, attr: str, value: Any) -> ContextManager:
r"""
Creates a context manager for setting transient object attributes.
Expand Down
44 changes: 29 additions & 15 deletions xfuse/utility/mask.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import itertools as it
from typing import Optional
import warnings

import cv2 as cv
import numpy as np
from PIL import Image
from scipy.ndimage import label
from scipy.ndimage.morphology import binary_fill_holes

from .core import center_crop
from .core import rescale, resize
from ..logging import INFO, log


Expand All @@ -22,7 +24,6 @@ def remove_fg_elements(mask: np.ndarray, size_threshold: float):

def compute_tissue_mask(
image: np.ndarray,
initial_mask: Optional[np.ndarray] = None,
convergence_threshold: float = 0.0001,
size_threshold: float = 0.01,
) -> np.ndarray:
Expand All @@ -32,16 +33,16 @@ def compute_tissue_mask(
"""
# pylint: disable=no-member
# ^ pylint fails to identify cv.* members
if initial_mask is None:
initial_mask = np.zeros(image.shape[:2], dtype=np.bool)
initial_mask_center = center_crop(
initial_mask,
tuple(int(round(x * 0.8)) for x in iter(initial_mask.shape)),
)
initial_mask_center[...] = True
original_shape = image.shape[:2]
scale_factor = 1000 / max(original_shape)

mask = cv.GC_PR_BGD * np.ones(image.shape[:2], dtype=np.uint8)
mask[initial_mask] = cv.GC_PR_FGD
image = rescale(image, scale_factor, resample=Image.NEAREST)
initial_mask = binary_fill_holes(
cv.blur(cv.Canny(cv.blur(image, (5, 5)), 100, 200), (5, 5))
)

mask = np.where(initial_mask, cv.GC_PR_FGD, cv.GC_PR_BGD)
mask = mask.astype(np.uint8)

bgd_model = np.zeros((1, 65), np.float64)
fgd_model = bgd_model.copy()
Expand All @@ -50,9 +51,20 @@ def compute_tissue_mask(

for i in it.count(1):
old_mask = mask.copy()
cv.grabCut(
image, mask, None, bgd_model, fgd_model, 1, cv.GC_INIT_WITH_MASK,
)
try:
cv.grabCut(
image,
mask,
None,
bgd_model,
fgd_model,
1,
cv.GC_INIT_WITH_MASK,
)
except cv.error as cv_err:
warnings.warn(f"Failed to mask tissue\n{str(cv_err).strip()}")
mask = np.full_like(mask, cv.GC_PR_FGD)
break
prop_changed = (mask != old_mask).sum() / np.prod(mask.shape)
log(INFO, f" Iteration {i}: {prop_changed=}")
if prop_changed < convergence_threshold:
Expand All @@ -61,6 +73,8 @@ def compute_tissue_mask(
mask = mask == cv.GC_PR_FGD
mask = cleanup_mask(mask, size_threshold)

mask = resize(mask, target_shape=original_shape, resample=Image.NEAREST)

return mask


Expand Down
3 changes: 2 additions & 1 deletion xfuse/utility/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from ..data.slide import FullSlide, Slide
from ..data.utility.misc import make_dataloader
from ..session import Session, get, require
from ..utility.mask import center_crop, cleanup_mask
from ..utility.core import center_crop
from ..utility.mask import cleanup_mask


__all__ = ["reduce_last_dimension", "visualize_metagenes"]
Expand Down