Skip to content

Commit

Permalink
SC_28 update filter shapes (#142)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneDefauw authored Oct 5, 2023
1 parent 8b3d10e commit 3de0ec2
Show file tree
Hide file tree
Showing 12 changed files with 449 additions and 200 deletions.
1 change: 0 additions & 1 deletion src/napari_sparrow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Define package version"""
__version__ = "0.0.1"
# TODO is __version__ needed/used?

import os
os.environ["USE_PYGEOS"] = "0"
Expand Down
22 changes: 16 additions & 6 deletions src/napari_sparrow/image/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from spatialdata import SpatialData
from spatialdata.models import SpatialElement
from spatialdata.models.models import ScaleFactors_t
from spatialdata.transformations import BaseTransformation
from spatialdata.transformations import BaseTransformation, Translation, Identity
from spatialdata.transformations._utils import (
_get_transformations,
_get_transformations_xarray,
Expand Down Expand Up @@ -59,10 +59,21 @@ def _get_boundary(
def _get_translation(
spatial_image: Union[SpatialImage, MultiscaleSpatialImage, DataArray]
) -> Tuple[float, float]:
transform_matrix = _get_transformation(spatial_image).to_affine_matrix(
input_axes=("x", "y"), output_axes=("x", "y")
)

translation=_get_transformation(spatial_image)

if not isinstance( translation, (Translation, Identity) ):
raise ValueError( f"Currently only transformations of type Translation are supported, "
f"while transformation associated with {spatial_image} is of type {type(translation)}.")

return _get_translation_values( translation )


def _get_translation_values( translation: Union[Translation, Identity]):
transform_matrix=translation.to_affine_matrix(
input_axes=("x", "y"), output_axes=("x", "y")
)

if (
transform_matrix[0, 0] == 1.0
and transform_matrix[0, 1] == 0.0
Expand All @@ -75,8 +86,7 @@ def _get_translation(
return transform_matrix[0, 2], transform_matrix[1, 2]
else:
raise ValueError(
f"The provided transform matrix '{transform_matrix}' associated with the SpatialImage "
f"element with name '{spatial_image.name}' represents more than just a translation, which is not currently supported."
f"The provided transform matrix {transform_matrix} represents more than just a translation."
)


Expand Down
25 changes: 9 additions & 16 deletions src/napari_sparrow/image/_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@

import dask.array as da
import numpy as np
import spatialdata
import torch
from cellpose import models
from dask.array import Array
from dask.array.overlap import coerce_depth, ensure_minimum_chunksize
from numpy.typing import NDArray
from shapely.affinity import translate
from spatialdata import SpatialData
from spatialdata.models.models import ScaleFactors_t
from spatialdata.transformations import Translation
Expand All @@ -21,7 +19,7 @@
_get_translation,
_substract_translation_crd,
)
from napari_sparrow.shape._shape import _mask_image_to_polygons
from napari_sparrow.shape._shape import _add_shapes_layer
from napari_sparrow.utils.pylogger import get_pylogger

log = get_pylogger(__name__)
Expand Down Expand Up @@ -194,7 +192,7 @@ def _segment_img_layer(

translation = Translation([tx, ty], axes=("x", "y"))

sdata=_add_label_layer(
sdata = _add_label_layer(
sdata,
arr=x_labels,
output_layer=output_labels_layer,
Expand All @@ -204,20 +202,16 @@ def _segment_img_layer(
overwrite=overwrite,
)

# only calculate shapes layer if is specified
# only calculate shapes layer if it is specified
if output_shapes_layer is not None:
se_labels = _get_spatial_element(sdata, layer=output_labels_layer)
# now calculate the polygons
polygons = _mask_image_to_polygons(mask=se_labels.data)

x_translation, y_translation = _get_translation(se_labels)
polygons["geometry"] = polygons["geometry"].apply(
lambda geom: translate(geom, xoff=x_translation, yoff=y_translation)
)

sdata.add_shapes(
name=output_shapes_layer,
shapes=spatialdata.models.ShapesModel.parse(polygons),
# convert the labels to polygons and add them as shapes layer to sdata
sdata = _add_shapes_layer(
sdata,
input=se_labels.data,
output_layer=output_shapes_layer,
transformation=translation,
overwrite=overwrite,
)

Expand Down Expand Up @@ -463,7 +457,6 @@ def _trim_masks(masks: Array, depth: Dict[int, int]) -> Array:

# now convert back to non-overlapping coordinates.

# check if this is correct TODO, thinks so
y_offset = max(0, y_start - (chunk_id[0] * 2 * depth[0] + depth[0]))
x_offset = max(0, x_start - (chunk_id[1] * 2 * depth[1] + depth[1]))

Expand Down
16 changes: 15 additions & 1 deletion src/napari_sparrow/io/_transcripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
def read_resolve_transcripts(
sdata: SpatialData,
path_count_matrix: str | Path,
overwrite: bool = False,
) -> SpatialData:
"""
Reads and adds Resolve transcript information to a SpatialData object.
Expand All @@ -26,6 +27,8 @@ def read_resolve_transcripts(
path_count_matrix : str | Path
Path to the file containing the transcripts information specific to Resolve.
Expected to contain x, y coordinates and a gene name.
overwrite: bool, default=False
If True overwrites the element (points layer) if it already exists.
Returns
-------
Expand All @@ -39,6 +42,7 @@ def read_resolve_transcripts(
"column_gene": 3,
"delimiter": "\t",
"header": None,
"overwrite": overwrite,
}

sdata = read_transcripts(*args, **kwargs)
Expand All @@ -49,6 +53,7 @@ def read_vizgen_transcripts(
sdata: SpatialData,
path_count_matrix: str | Path,
path_transform_matrix: str | Path,
overwrite: bool = False,
) -> SpatialData:
"""
Reads and adds Vizgen transcript information to a SpatialData object.
Expand All @@ -62,6 +67,8 @@ def read_vizgen_transcripts(
Expected to contain x, y coordinates and a gene name.
path_transform_matrix : str | Path
Path to the transformation matrix for the affine transformation.
overwrite: bool, default=False
If True overwrites the element (points layer) if it already exists.
Returns
-------
Expand All @@ -75,6 +82,7 @@ def read_vizgen_transcripts(
"column_gene": 8,
"delimiter": ",",
"header": 0,
"overwrite": overwrite,
}

sdata = read_transcripts(*args, **kwargs)
Expand All @@ -84,6 +92,7 @@ def read_vizgen_transcripts(
def read_stereoseq_transcripts(
sdata: SpatialData,
path_count_matrix: str | Path,
overwrite: bool = False,
) -> SpatialData:
"""
Reads and adds Stereoseq transcript information to a SpatialData object.
Expand All @@ -95,6 +104,8 @@ def read_stereoseq_transcripts(
path_count_matrix : str | Path
Path to the file containing the transcripts information specific to Stereoseq.
Expected to contain x, y coordinates, gene name, and a midcount column.
overwrite: bool, default=False
If True overwrites the element (points layer) if it already exists.
Returns
-------
Expand All @@ -109,6 +120,7 @@ def read_stereoseq_transcripts(
"column_midcount": 3,
"delimiter": ",",
"header": 0,
"overwrite": overwrite,
}

sdata = read_transcripts(*args, **kwargs)
Expand Down Expand Up @@ -222,7 +234,9 @@ def transform_coordinates(df):
# Reorder
transformed_ddf = transformed_ddf[["pixel_x", "pixel_y", "gene"]]

sdata = _add_transcripts_to_sdata(sdata, transformed_ddf, points_layer, overwrite=overwrite)
sdata = _add_transcripts_to_sdata(
sdata, transformed_ddf, points_layer, overwrite=overwrite
)

return sdata

Expand Down
51 changes: 29 additions & 22 deletions src/napari_sparrow/plot/_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from napari_sparrow.image._image import (
_apply_transform,
_get_boundary,
_unapply_transform,
_get_spatial_element,
_unapply_transform,
)
from napari_sparrow.shape import intersect_rectangles
from napari_sparrow.utils.pylogger import get_pylogger
Expand Down Expand Up @@ -57,8 +57,8 @@ def plot_image(

def plot_shapes(
sdata: SpatialData,
img_layer: str | Iterable[str] = None,
shapes_layer: str | Iterable[str] = None,
img_layer: Optional[str | Iterable[str]] = None,
shapes_layer: Optional[str | Iterable[str]] = None,
channel: Optional[int | Iterable[int]] = None,
crd: Optional[Tuple[int, int, int, int]] = None,
figsize: Optional[Tuple[int, int]] = None,
Expand Down Expand Up @@ -165,7 +165,7 @@ def plot_shapes(

# if channel is None, get the number of channels from the first img_layer given, maybe print a message about this.
if channel is None:
se=_get_spatial_element( sdata, layer=img_layer[0] )
se = _get_spatial_element(sdata, layer=img_layer[0])
channels = se.c.data
else:
channels = channel
Expand Down Expand Up @@ -222,7 +222,7 @@ def _plot_shapes( # FIXME: rename, this does not always plot a shapes layer any
crd: Tuple[int, int, int, int] = None,
vmin: Optional[float] = None,
vmax: Optional[float] = None,
plot_filtered: bool = False,
shapes_layer_filtered: Optional[str | Iterable[str]] = None,
img_title: bool = False,
shapes_title: bool = False,
channel_title: bool = True,
Expand Down Expand Up @@ -255,8 +255,8 @@ def _plot_shapes( # FIXME: rename, this does not always plot a shapes layer any
Lower bound for color scale for continuous data. Given as a percentile.
vmax : float or None, optional
Upper bound for color scale for continuous data. Given as a percentile.
plot_filtered : bool, default=False
Whether to plot the cells that were filtered out in previous steps.
shapes_layer_filtered : str or Iterable[str], optional
Extra shapes layers to plot. E.g. shapes filtered out in previous preprocessing steps.
img_title: bool, default=False
A flag indicating whether the image layer's name should be added to the title of the plot.
shapes_title: bool, default=False
Expand All @@ -278,7 +278,15 @@ def _plot_shapes( # FIXME: rename, this does not always plot a shapes layer any
if img_layer is None:
img_layer = [*sdata.images][-1]

se = _get_spatial_element( sdata, layer=img_layer )
if shapes_layer_filtered is not None:
shapes_layer_filtered = (
list(shapes_layer_filtered)
if isinstance(shapes_layer_filtered, Iterable)
and not isinstance(shapes_layer_filtered, str)
else [shapes_layer_filtered]
)

se = _get_spatial_element(sdata, layer=img_layer)

# Update coords
se, x_coords_orig, y_coords_orig = _apply_transform(se)
Expand Down Expand Up @@ -350,8 +358,8 @@ def _plot_shapes( # FIXME: rename, this does not always plot a shapes layer any
x=slice(crd[0], crd[1]), y=slice(crd[2], crd[3])
).plot.imshow(cmap="gray", robust=True, ax=ax, add_colorbar=False)

if shapes_layer:
sdata[shapes_layer].cx[crd[0] : crd[1], crd[2] : crd[3]].plot(
if shapes_layer is not None:
sdata.shapes[shapes_layer].cx[crd[0] : crd[1], crd[2] : crd[3]].plot(
ax=ax,
edgecolor="white",
column=column,
Expand All @@ -363,18 +371,17 @@ def _plot_shapes( # FIXME: rename, this does not always plot a shapes layer any
vmax=vmax, # np.percentile(column,vmax),
vmin=vmin, # np.percentile(column,vmin)
)
if plot_filtered:
for i in [*sdata.shapes]:
if f"filtered_{shapes_layer}" in i:
sdata[i].cx[crd[0] : crd[1], crd[2] : crd[3]].plot(
ax=ax,
edgecolor="red",
linewidth=1,
alpha=alpha,
legend=True,
aspect=1,
cmap="gray",
)
if shapes_layer_filtered is not None:
for i in shapes_layer_filtered:
sdata.shapes[i].cx[crd[0] : crd[1], crd[2] : crd[3]].plot(
ax=ax,
edgecolor="red",
linewidth=1,
alpha=alpha,
legend=True,
aspect=1,
cmap="gray",
)
ax.axes.set_aspect(aspect)
ax.set_xlim(crd[0], crd[1])
ax.set_ylim(crd[2], crd[3])
Expand Down
2 changes: 1 addition & 1 deletion src/napari_sparrow/shape/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from ._cell_expansion import create_voronoi_boundaries
from ._shape import intersect_rectangles
from ._shape import intersect_rectangles
40 changes: 35 additions & 5 deletions src/napari_sparrow/shape/_cell_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,37 @@ def create_voronoi_boundaries(
sdata: SpatialData,
radius: int = 0,
shapes_layer: str = "segmentation_mask_boundaries",
):
) -> SpatialData:
"""
Create Voronoi boundaries from the shapes layer of the provided SpatialData object.
Given spatial data and a radius, this function calculates Voronoi boundaries
and expands these boundaries based on the radius.
Parameters
----------
sdata : SpatialData
The spatial data object on which Voronoi boundaries will be created.
radius : int, optional
The expansion radius for the Voronoi boundaries, by default 0.
If provided, Voronoi boundaries will be expanded by this radius.
Must be non-negative.
shapes_layer : str, optional
The name of the layer in `sdata` representing shapes used to derive
Voronoi boundaries. Default is "segmentation_mask_boundaries".
Returns
-------
SpatialData
Modified `sdata` object with the Voronoi boundaries created and
possibly expanded.
Raises
------
ValueError
If the provided radius is negative.
"""

if radius < 0:
raise ValueError(
f"radius should be >0, provided value for radius is '{radius}'"
Expand All @@ -22,16 +52,16 @@ def create_voronoi_boundaries(
# sdata[shape_layer].index = list(map(str, sdata[shape_layer].index))

si = sdata[[*sdata.images][0]]
# need to add translation in x and y direction to size of the image,
# need to add translation in x and y direction to size of the image,
# to account for use case where si is already cropped
tx, ty = _get_translation(si)

boundary = Polygon(
[
(tx, ty),
(tx+si.sizes["x"], ty),
(tx+si.sizes["x"], ty+si.sizes["y"]),
(tx, ty+si.sizes["y"]),
(tx + si.sizes["x"], ty),
(tx + si.sizes["x"], ty + si.sizes["y"]),
(tx, ty + si.sizes["y"]),
]
)

Expand Down
Loading

0 comments on commit 3de0ec2

Please sign in to comment.