Skip to content

Commit

Permalink
improve masking
Browse files Browse the repository at this point in the history
  • Loading branch information
ludvb committed Nov 1, 2020
1 parent 1d33b7b commit 62e9c10
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 16 deletions.
3 changes: 2 additions & 1 deletion xfuse/convert/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import pandas as pd
from PIL import Image

from ..utility.core import rescale
from ..utility.mask import compute_tissue_mask
from .utility import rescale, write_data
from .utility import write_data


def run(
Expand Down
2 changes: 1 addition & 1 deletion xfuse/convert/st.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import pandas as pd
from PIL import Image

from ..utility.core import rescale
from .utility import (
Spot,
labels_from_spots,
mask_tissue,
rescale,
write_data,
)

Expand Down
2 changes: 1 addition & 1 deletion xfuse/convert/visium.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from PIL import Image
from scipy.sparse import csr_matrix

from ..utility.core import rescale
from .utility import (
Spot,
labels_from_spots,
mask_tissue,
rescale,
write_data,
)

Expand Down
53 changes: 52 additions & 1 deletion xfuse/utility/core.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
from typing import Any, ContextManager, Protocol, Tuple, TypeVar, Union
from typing import (
Any,
ContextManager,
Protocol,
Tuple,
TypeVar,
Sequence,
Union,
)

import warnings

import numpy as np
from PIL import Image


__all__ = [
"center_crop",
"rescale",
"resize",
"temp_attr",
]

Expand Down Expand Up @@ -47,6 +60,44 @@ def center_crop(x: ArrayType, target_shape: Tuple[int, ...]) -> ArrayType:
]


def rescale(
image: np.ndarray, scaling_factor: float, resample: int = Image.NEAREST
) -> np.ndarray:
r"""
Rescales image by a given `scaling_factor`
:param image: Image array
:param scaling_factor: Scaling factor
:param resample: Resampling filter
:returns: The rescaled image
"""
image = Image.fromarray(image)
image = image.resize(
[round(x * scaling_factor) for x in image.size], resample=resample,
)
image = np.array(image)
return image


def resize(
image: np.ndarray,
target_shape: Sequence[int],
resample: int = Image.NEAREST,
) -> np.ndarray:
r"""
Resizes image to a given `target_shape`
:param image: Image array
:param target_shape: Target shape
:param resample: Resampling filter
:returns: The rescaled image
"""
image = Image.fromarray(image)
image = image.resize(target_shape[::-1], resample=resample)
image = np.array(image)
return image


def temp_attr(obj: object, attr: str, value: Any) -> ContextManager:
r"""
Creates a context manager for setting transient object attributes.
Expand Down
25 changes: 14 additions & 11 deletions xfuse/utility/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

import cv2 as cv
import numpy as np
from PIL import Image
from scipy.ndimage import label
from scipy.ndimage.morphology import binary_fill_holes

from .core import center_crop
from .core import rescale, resize
from ..logging import INFO, log


Expand All @@ -22,7 +24,6 @@ def remove_fg_elements(mask: np.ndarray, size_threshold: float):

def compute_tissue_mask(
image: np.ndarray,
initial_mask: Optional[np.ndarray] = None,
convergence_threshold: float = 0.0001,
size_threshold: float = 0.01,
) -> np.ndarray:
Expand All @@ -32,16 +33,16 @@ def compute_tissue_mask(
"""
# pylint: disable=no-member
# ^ pylint fails to identify cv.* members
if initial_mask is None:
initial_mask = np.zeros(image.shape[:2], dtype=np.bool)
initial_mask_center = center_crop(
initial_mask,
tuple(int(round(x * 0.8)) for x in iter(initial_mask.shape)),
)
initial_mask_center[...] = True
original_shape = image.shape[:2]
scale_factor = 1000 / max(original_shape)

image = rescale(image, scale_factor, resample=Image.NEAREST)
initial_mask = binary_fill_holes(
cv.blur(cv.Canny(cv.blur(image, (5, 5)), 100, 200), (5, 5))
)

mask = cv.GC_PR_BGD * np.ones(image.shape[:2], dtype=np.uint8)
mask[initial_mask] = cv.GC_PR_FGD
mask = np.where(initial_mask, cv.GC_PR_FGD, cv.GC_PR_BGD)
mask = mask.astype(np.uint8)

bgd_model = np.zeros((1, 65), np.float64)
fgd_model = bgd_model.copy()
Expand All @@ -61,6 +62,8 @@ def compute_tissue_mask(
mask = mask == cv.GC_PR_FGD
mask = cleanup_mask(mask, size_threshold)

mask = resize(mask, target_shape=original_shape, resample=Image.NEAREST)

return mask


Expand Down
3 changes: 2 additions & 1 deletion xfuse/utility/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from ..data.slide import FullSlide, Slide
from ..data.utility.misc import make_dataloader
from ..session import Session, get, require
from ..utility.mask import center_crop, cleanup_mask
from ..utility.core import center_crop
from ..utility.mask import cleanup_mask


__all__ = ["reduce_last_dimension", "visualize_metagenes"]
Expand Down

0 comments on commit 62e9c10

Please sign in to comment.