Skip to content

Commit

Permalink
formatter and linting
Browse files Browse the repository at this point in the history
  • Loading branch information
vreuter committed Mar 17, 2024
1 parent 6fa533b commit cb53ecb
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 37 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ profile = "black"

[tool.mypy]
files = ['*.py', 'spotfishing/*.py', 'tests/*.py']
exclude = ['^deprecated\.py$']
plugins = ["pydantic.mypy"]
warn_redundant_casts = true
warn_unused_ignores = true
Expand Down
13 changes: 13 additions & 0 deletions spotfishing/_numeric_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Aliases for numeric types"""

from typing import Union
import numpy as np

__author__ = "Vince Reuter"
__credits__ = ["Vince Reuter"]

__all__ = ["NumpyInt", "NumpyFloat", "PixelValue"]

NumpyInt = Union[np.int8, np.int16, np.int32, np.int64]
NumpyFloat = Union[np.float16, np.float32, np.float64]
PixelValue = Union[np.int8, np.int16]
9 changes: 5 additions & 4 deletions spotfishing/detection_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from typing import TYPE_CHECKING, Iterable

if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
import pandas as pd

from ._constants import *
from ._exceptions import DimensionalityError
from ._numeric_types import *

__author__ = "Vince Reuter"
__all__ = ["DetectionResult"]
Expand Down Expand Up @@ -48,12 +49,12 @@ class DetectionResult:
"""

table: "pd.DataFrame"
image: "np.ndarray"
labels: "np.ndarray"
image: "npt.NDArray[PixelValue]"
labels: "npt.NDArray[NumpyInt]"

def __post_init__(self) -> None:
"""Validate that the structure and values of the inputs are as required."""
errors = []
errors: list[Exception] = []
cols = list(self.table.columns)
if cols != DETECTION_RESULT_TABLE_COLUMNS:
errors.append(IllegalDetectionResultTableColumns(observed_columns=cols))
Expand Down
67 changes: 40 additions & 27 deletions spotfishing/detectors.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
"""Different spot detection implementations"""

from typing import Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Tuple, Union

import numpy as np

if TYPE_CHECKING:
import numpy.typing as npt
import pandas as pd
from scipy import ndimage as ndi
from skimage.filters import gaussian
from skimage.filters import gaussian # type: ignore[import-untyped]
from skimage.measure import regionprops_table
from skimage.morphology import ball, remove_small_objects, white_tophat
from skimage.segmentation import expand_labels
from skimage.segmentation import expand_labels # type: ignore[import-untyped]

from ._constants import *
from ._exceptions import DimensionalityError
from ._numeric_types import *
from .detection_result import (
ROI_CENTROID_COLUMN_RENAMING,
SKIMAGE_REGIONPROPS_TABLE_COLUMNS_EXPANDED,
Expand All @@ -24,13 +28,13 @@
__all__ = ["detect_spots_dog", "detect_spots_int"]

Numeric = Union[int, float]
NumpyInt = Union[np.int8, np.int16, np.int32, np.int64]
NumpyFloat = Union[np.float16, np.float32, np.float64]
PixelValue = Union[np.int8, np.int16]


def detect_spots_dog(
*, input_image, spot_threshold: Numeric, expand_px: Optional[Numeric]
*,
input_image: npt.NDArray[PixelValue],
spot_threshold: Numeric,
expand_px: Optional[Numeric],
) -> DetectionResult:
"""Spot detection by difference of Gaussians filter
Expand All @@ -51,15 +55,18 @@ def detect_spots_dog(
"""
_check_input_image(input_image)
img = _preprocess_for_difference_of_gaussians(input_image)
labels, _ = ndi.label(img > spot_threshold)
labels, _ = ndi.label(img > spot_threshold) # type: ignore[no-untyped-call]
spot_props, labels = _build_props_table(
labels=labels, input_image=input_image, expand_px=expand_px
)
return DetectionResult(table=spot_props, image=img, labels=labels)


def detect_spots_int(
*, input_image, spot_threshold: Numeric, expand_px: Optional[Numeric]
*,
input_image: npt.NDArray[PixelValue],
spot_threshold: Numeric,
expand_px: Optional[Numeric],
) -> DetectionResult:
"""Spot detection by intensity filter
Expand All @@ -82,9 +89,9 @@ def detect_spots_int(
# See: https://github.com/gerlichlab/looptrace/issues/125
_check_input_image(input_image)
binary = input_image > spot_threshold
binary = ndi.binary_fill_holes(binary)
struct = ndi.generate_binary_structure(input_image.ndim, 2)
labels, num_obj = ndi.label(binary, structure=struct)
binary = ndi.binary_fill_holes(binary) # type: ignore[no-untyped-call]
struct = ndi.generate_binary_structure(input_image.ndim, 2) # type: ignore[no-untyped-call]
labels, num_obj = ndi.label(binary, structure=struct) # type: ignore[no-untyped-call]
if num_obj > 1:
labels = remove_small_objects(labels, min_size=5)
spot_props, labels = _build_props_table(
Expand All @@ -94,32 +101,36 @@ def detect_spots_int(


def _build_props_table(
*, labels: np.ndarray[NumpyInt], input_image: np.ndarray[PixelValue], expand_px: Optional[int]
) -> Tuple[pd.DataFrame, np.ndarray[NumpyInt]]:
*,
labels: npt.NDArray[NumpyInt],
input_image: npt.NDArray[PixelValue],
expand_px: Optional[Numeric],
) -> Tuple[pd.DataFrame, npt.NDArray[NumpyInt]]:
if expand_px:
labels = expand_labels(labels, expand_px)
if np.all(labels == 0):
# No substructures (ROIs) exist.
spot_props = pd.DataFrame(columns=SKIMAGE_REGIONPROPS_TABLE_COLUMNS_EXPANDED)
else:
spot_props = regionprops_table(
label_image=labels,
intensity_image=input_image,
properties=(
ROI_LABEL_KEY,
ROI_CENTROID_KEY,
ROI_AREA_KEY,
ROI_MEAN_INTENSITY_KEY,
),
spot_props = pd.DataFrame(
regionprops_table(
label_image=labels,
intensity_image=input_image,
properties=(
ROI_LABEL_KEY,
ROI_CENTROID_KEY,
ROI_AREA_KEY,
ROI_MEAN_INTENSITY_KEY,
),
)
)
spot_props = pd.DataFrame(spot_props)
spot_props = spot_props.drop(["label"], axis=1, errors="ignore")
spot_props = spot_props.rename(columns=dict(ROI_CENTROID_COLUMN_RENAMING))
spot_props = spot_props.reset_index(drop=True)
return spot_props, labels


def _check_input_image(img: np.ndarray[PixelValue]) -> None:
def _check_input_image(img: npt.NDArray[PixelValue]) -> None:
if not isinstance(img, np.ndarray):
raise TypeError(
f"Expected numpy array for input image but got {type(img).__name__}"
Expand All @@ -130,9 +141,11 @@ def _check_input_image(img: np.ndarray[PixelValue]) -> None:
)


def _preprocess_for_difference_of_gaussians(input_image: np.ndarray[PixelValue]) -> np.ndarray[PixelValue]:
def _preprocess_for_difference_of_gaussians(
input_image: npt.NDArray[PixelValue],
) -> npt.NDArray[PixelValue]:
img = white_tophat(image=input_image, footprint=ball(2))
img = gaussian(img, 0.8) - gaussian(img, 1.3)
img = img / gaussian(input_image, 3)
img = (img - np.mean(img)) / np.std(img)
return img
return img # type: ignore[no-any-return]
6 changes: 4 additions & 2 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
__author__ = "Vince Reuter"
__credits__ = ["Vince Reuter"]

__all__ = ["load_image_file", ]
__all__ = [
"load_image_file",
]


def get_img_data_file(fn: str) -> Path:
Expand All @@ -18,4 +20,4 @@ def get_img_data_file(fn: str) -> Path:

def load_image_file(fn: str) -> np.ndarray:
"""Load an input image from disk, with the given name and stored in test data inputs folder."""
return np.load(get_img_data_file(fn)) # type: ignore
return np.load(get_img_data_file(fn)) # type: ignore
5 changes: 2 additions & 3 deletions tests/test_accord_with_original_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@

@pytest.mark.skip("not implemented")
@pytest.mark.parametrize(
["detect", "kwargs"],
[(detect_spots_dog, {}), (detect_spots_int, {})]
)
["detect", "kwargs"], [(detect_spots_dog, {}), (detect_spots_int, {})]
)
@pytest.mark.parametrize("input_image", [])
def test_output_is_correct_with_original_settings(detect, input_image, kwargs):
detect(input_image=input_image, **kwargs)
4 changes: 3 additions & 1 deletion tests/test_old_new_equivalence.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def get_img_data_file(fn: str) -> Path:
)
def test_eqv(input_image, old_fun, new_fun, threshold, expand_px):
old_table, _, _ = old_fun(input_image, threshold)
new_res = new_fun(input_image=input_image, spot_threshold=threshold, expand_px=expand_px)
new_res = new_fun(
input_image=input_image, spot_threshold=threshold, expand_px=expand_px
)
new_table = new_res.table
assert np.all(old_table.index == new_table.index)
cols = ["zc", "yc", "xc", "area", "intensity_mean"]
Expand Down

0 comments on commit cb53ecb

Please sign in to comment.