Skip to content

Commit

Permalink
SC_25 support for mac gpu + crop support apply (#139)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneDefauw authored Sep 28, 2023
1 parent 4838cea commit a6b3689
Show file tree
Hide file tree
Showing 14 changed files with 87 additions and 52 deletions.
4 changes: 2 additions & 2 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dependencies:
- opencv=4.5.5
- pip=22.2.2
- python=3.10.8
- pytorch=1.12.1
- pytorch=1.13.0
- rasterio=1.3.2
- scanpy=1.9.4
- scipy=1.8.0
Expand All @@ -18,7 +18,7 @@ dependencies:
- pip:
- anndata==0.9.2
- basicpy==1.0.0
- cellpose==2.0.5
- cellpose==2.2.3
- jax==0.4.6
- jaxlib==0.4.6
- shapely==2.0.1
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ packages = find:
install_requires =
pandas==1.4.3,
spatialdata>=0.0.9
cellpose==2.0.5
cellpose>=2.0.5
squidpy==1.2.0
#matplotlib<3.7 # scanpy not compatible with matplotlib>=3.7
scanpy>=1.9.1
Expand Down
40 changes: 29 additions & 11 deletions src/napari_sparrow/image/_apply.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from types import MappingProxyType
from typing import Any, Callable, Iterable, Mapping, Optional
from typing import Any, Callable, Iterable, Mapping, Optional, Tuple

import dask.array as da
import spatialdata
from dask.array import Array
from dask.array.overlap import coerce_depth
from numpy.typing import NDArray
from spatialdata import SpatialData
from spatialdata.transformations import get_transformation, set_transformation
from spatialdata.transformations import Translation, set_transformation

from napari_sparrow.image._image import (_get_translation,
_substract_translation_crd)
from napari_sparrow.utils.pylogger import get_pylogger

log = get_pylogger(__name__)
Expand All @@ -21,6 +23,7 @@ def apply(
output_layer: Optional[str] = None,
channel: Optional[int | Iterable[int]] = None,
chunks: str | tuple[int, int] | int | None = None,
crd: Optional[Tuple[int, int, int, int]] = None,
overwrite: bool = False,
fn_kwargs: Mapping[str, Any] = MappingProxyType({}),
**kwargs: Any,
Expand All @@ -44,6 +47,8 @@ def apply(
chunks : str | tuple[int, int] | int | None, default=None
Specification for rechunking the data before applying the function.
If specified, dask's map_overlap or map_blocks is used depending on the occurence of the "depth" parameter in kwargs.
crd : Optional[Tuple[int, int, int, int]], default=None
The coordinates specifying the region of the image to be processed. Defines the bounds (x_min, x_max, y_min, y_max).
overwrite : bool, default=False
If True, overwrites the output layer if it already exists in `sdata`.
fn_kwargs : Mapping[str, Any], default=MappingProxyType({})
Expand Down Expand Up @@ -79,13 +84,13 @@ def apply(
>>> def my_function( image, parameter ):
... return image*parameter
>>> fn_kwargs={ "parameter": ChannelList( [2,3] ) }
>>> sdata = apply(sdata, my_function, img_layer="raw_image", output_layer="processed_image", channel=None, fn_kwargs=fn_kwargs )
>>> fn_kwargs={ "parameter": ChannelList( [2,3] ) }
>>> sdata = apply(sdata, my_function, img_layer="raw_image", output_layer="processed_image", channel=None, fn_kwargs=fn_kwargs)
Apply the same function to only the first channel of the image:
>>> fn_kwargs={ "parameter": 2 }
>>> sdata = apply(sdata, my_function, img_layer="raw_image", output_layer="processed_image", channel=0, fn_kwargs=fn_kwargs )
>>> sdata = apply(sdata, my_function, img_layer="raw_image", output_layer="processed_image", channel=0, fn_kwargs=fn_kwargs)
"""

if img_layer is None:
Expand All @@ -104,6 +109,10 @@ def apply_func(
fn_kwargs: Mapping[str, Any] = MappingProxyType({}),
) -> Array:
if chunks is None:
# if dask array, we want to rechunk,
# because taking a crop could have caused irregular chunks
if isinstance( arr, Array ):
arr=arr.rechunk( arr.chunksize )
arr = func(arr, **fn_kwargs)
return da.asarray(arr)
arr = da.asarray(arr).rechunk(chunks)
Expand Down Expand Up @@ -165,26 +174,35 @@ def adjust_depth(depth, chunksize, depth_dim):
# store results per channel
results = []

if crd:
crd = _substract_translation_crd(sdata[img_layer], crd)

for ch, _fn_kwargs in zip(channel, _fn_kwargs_channel):
arr = sdata[img_layer].isel(c=ch).data
if len(arr.shape) != 2:
raise ValueError(
f"Array is of dimension {arr.shape}, currently only 2D images are supported."
)
# need to pass correct value from fn_kwargs to apply_func
if crd:
arr = arr[crd[2] : crd[3], crd[0] : crd[1]]
# passing correct value from fn_kwargs to apply_func
arr = apply_func(func, arr, _fn_kwargs)
results.append(arr)

arr = da.stack(results, axis=0)

spatial_image = spatialdata.models.Image2DModel.parse(arr, dims=("c", "y", "x"))

# TODO maybe also make it possible to send transformation with the apply function
# now by default we copy transformation of old img_layer to new img_layer
trf = get_transformation(sdata[img_layer])
set_transformation(spatial_image, trf)
tx, ty = _get_translation(sdata[img_layer])

if crd:
tx = tx + crd[0]
ty = ty + crd[2]

translation = Translation([tx, ty], axes=("x", "y"))

set_transformation(spatial_image, translation)

# during adding of image it is written to zarr store
sdata.add_image(name=output_layer, image=spatial_image, overwrite=overwrite)

return sdata
Expand Down
5 changes: 4 additions & 1 deletion src/napari_sparrow/image/_contrast.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def enhance_contrast(
chunks: Optional[str | tuple[int, int] | int] = 10000,
depth: Tuple[int, int] | Dict[int, int] | int = 3000,
output_layer: str = "clahe",
crd: Optional[Tuple[int, int, int, int]] = None,
overwrite: bool = False,
) -> SpatialData:
"""
Expand All @@ -34,14 +35,15 @@ def enhance_contrast(
The default value is 3.5.
chunks : str | tuple[int, int] | int, optional
The size of the chunks used during dask image processing.
Larger chunks may lead to increased memory usage but faster processing.
The default value is 10000.
depth : Tuple[int, int] | Dict[ int, int ] | int, optional
The overlapping depth used in dask array map_overlap operation.
The default value is 3000.
output_layer : str, optional
The name of the image layer where the enhanced image will be stored.
The default value is "clahe".
crd : Optional[Tuple[int, int, int, int]], default=None
The coordinates specifying the region of the image to be processed. Defines the bounds (x_min, x_max, y_min, y_max).
overwrite: bool
If True overwrites the element if it already exists.
Expand Down Expand Up @@ -71,6 +73,7 @@ def _apply_clahe(image: NDArray, contrast_clip: float = 3.5) -> NDArray:
channel=None, # channel==None -> apply apply_clahe to each layer seperately
fn_kwargs={"contrast_clip": contrast_clip},
depth=depth,
crd=crd,
overwrite=overwrite,
)

Expand Down
42 changes: 23 additions & 19 deletions src/napari_sparrow/image/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,54 +6,58 @@
from spatial_image import SpatialImage
from spatialdata.transformations import get_transformation

# TODO: check if we want to keep the functions here in this generally named ïmage.py, there are mostly related to the translation/transformation, should the filename reflect this?
# TODO: rename _get_translation to _get_image_translation for consistency with _get_image_boundary? or use _get_translation() and _get_boundary() because we're in the image package anyway, so the image is understood already?

# FIXME: we're type hinting crd as Tuple but often use a list for bounding box rectangles


def _substract_translation_crd(
spatialimage: SpatialImage, crd=Tuple[int, int, int, int]
spatial_image: SpatialImage, crd=Tuple[int, int, int, int]
) -> Optional[Tuple[int, int, int, int]]:
tx, ty = _get_translation(spatialimage)
tx, ty = _get_translation(spatial_image)

_crd = crd
crd = [
int(max(0, crd[0] - tx)),
max(0, int(min(spatialimage.sizes["x"], crd[1] - tx))),
max(0, int(min(spatial_image.sizes["x"], crd[1] - tx))),
int(max(0, crd[2] - ty)),
max(0, int(min(spatialimage.sizes["y"], crd[3] - ty))),
max(0, int(min(spatial_image.sizes["y"], crd[3] - ty))),
]

if crd[1] - crd[0] <= 0 or crd[3] - crd[2] <= 0:
warnings.warn(
f"Crop param {_crd} after correction for possible translation on "
f"spatialimage object '{spatialimage.name}' is "
f"SpatialImage object '{spatial_image.name}' is "
f"'{crd}. Falling back to setting crd to 'None'."
)
crd = None

return crd


def _get_image_boundary(spatial_image):
def _get_boundary(spatial_image: SpatialImage) -> Tuple[int, int, int, int]:
tx, ty = _get_translation(spatial_image)
width = spatial_image.sizes["x"]
height = spatial_image.sizes["y"]
return [int(tx), int(tx + width), int(ty), int(ty + height)]
return (int(tx), int(tx + width), int(ty), int(ty + height))


def _get_translation(spatial_image):
def _get_translation(spatial_image: SpatialImage) -> Tuple[float, float]:
transform_matrix = get_transformation(spatial_image).to_affine_matrix(
input_axes=("x", "y"), output_axes=("x", "y")
)

# Extract translation components from transformation matrix
tx = transform_matrix[:, -1][
0
] # FIXME: why not directly access the correct element in the matrix?
ty = transform_matrix[:, -1][1]
return tx, ty
if (
transform_matrix[0, 0] == 1.0
and transform_matrix[0, 1] == 0.0
and transform_matrix[1, 0] == 0.0
and transform_matrix[1, 1] == 1.0
and transform_matrix[2, 0] == 0.0
and transform_matrix[2, 1] == 0.0
and transform_matrix[2, 2] == 1.0
):
return transform_matrix[0, 2], transform_matrix[1, 2]
else:
raise ValueError(
f"The provided transform matrix '{transform_matrix}' associated with the SpatialImage "
f"element with name '{spatial_image.name}' represents more than just a translation, which is not currently supported."
)


# FIXME: the type "SpatialImage" is probably too restrictive here, see type of sdata[layer], which is more general
Expand Down
6 changes: 5 additions & 1 deletion src/napari_sparrow/image/_minmax.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, List, Optional
from typing import Iterable, List, Optional, Tuple

from dask.array import Array
from dask_image.ndfilters import maximum_filter, minimum_filter
Expand All @@ -12,6 +12,7 @@ def min_max_filtering(
img_layer: Optional[str] = None,
size_min_max_filter: int | List[int] = 85,
output_layer="min_max_filtered",
crd: Optional[Tuple[int, int, int, int]] = None,
overwrite: bool = False,
) -> SpatialData:
"""
Expand All @@ -30,6 +31,8 @@ def min_max_filtering(
must match the number of channels. Defaults to 85.
output_layer : str, optional
The name of the output layer. Defaults to "min_max_filtered".
crd : Optional[Tuple[int, int, int, int]], default=None
The coordinates specifying the region of the image to be processed. Defines the bounds (x_min, x_max, y_min, y_max).
overwrite: bool
If True overwrites the element if it already exists.
Expand Down Expand Up @@ -76,6 +79,7 @@ def _apply_min_max_filtering(image: Array, size_min_max_filter: int = 85) -> Arr
chunks=None,
channel=None,
fn_kwargs={"size_min_max_filter": size_min_max_filter},
crd=crd,
overwrite=overwrite,
)

Expand Down
11 changes: 6 additions & 5 deletions src/napari_sparrow/image/_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
from spatialdata import SpatialData
from spatialdata.transformations import Translation, set_transformation

from napari_sparrow.image._image import (_get_translation,
_substract_translation_crd)
from napari_sparrow.image._image import _get_translation, _substract_translation_crd
from napari_sparrow.shape._shape import _mask_image_to_polygons
from napari_sparrow.utils.pylogger import get_pylogger

Expand All @@ -34,7 +33,7 @@ def _cellpose(
channels: List[int] = [0, 0],
device: str = "cpu",
) -> NDArray:
gpu = torch.cuda.is_available()
gpu = torch.cuda.is_available() or torch.backends.mps.is_available()
model = models.Cellpose(gpu=gpu, model_type=model_type, device=torch.device(device))
masks, _, _, _ = model.eval(
img,
Expand Down Expand Up @@ -277,8 +276,10 @@ def _segment(
# if trim==True --> use squidpy's way of handling neighbouring blocks
if trim:
from dask_image.ndmeasure._utils._label import (
connected_components_delayed, label_adjacency_graph,
relabel_blocks)
connected_components_delayed,
label_adjacency_graph,
relabel_blocks,
)

# max because labels are not continuous (and won't be continuous)
label_groups = label_adjacency_graph(x_labels, None, x_labels.max())
Expand Down
4 changes: 2 additions & 2 deletions src/napari_sparrow/image/_transcripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from spatialdata import SpatialData
from spatialdata.transformations import Translation, set_transformation

from napari_sparrow.image._image import _get_image_boundary
from napari_sparrow.image._image import _get_boundary
from napari_sparrow.utils.pylogger import get_pylogger

log = get_pylogger(__name__)
Expand Down Expand Up @@ -82,7 +82,7 @@ def transcript_density(
# get image boundary from last image layer if img_layer is None
if img_layer is None:
img_layer = [*sdata.images][-1]
img_boundary = _get_image_boundary(sdata[img_layer])
img_boundary = _get_boundary(sdata[img_layer])

# if crd is None, get boundary from image at img_layer if given,
if crd is None:
Expand Down
8 changes: 7 additions & 1 deletion src/napari_sparrow/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ def clean(self, sdata: SpatialData) -> SpatialData:
sdata = nas.im.min_max_filtering(
sdata=sdata,
img_layer=self.cleaned_image_name,
crd=self.cfg.clean.crop_param
if self.cfg.clean.crop_param is not None
else None,
size_min_max_filter=list(self.cfg.clean.size_min_max_filter)
if isinstance(self.cfg.clean.size_min_max_filter, ListConfig)
else self.cfg.clean.size_min_max_filter,
Expand Down Expand Up @@ -182,6 +185,9 @@ def clean(self, sdata: SpatialData) -> SpatialData:
sdata = nas.im.enhance_contrast(
sdata=sdata,
img_layer=self.cleaned_image_name,
crd=self.cfg.clean.crop_param
if self.cfg.clean.crop_param is not None
else None,
contrast_clip=list(self.cfg.clean.contrast_clip)
if isinstance(self.cfg.clean.contrast_clip, ListConfig)
else self.cfg.clean.contrast_clip,
Expand Down Expand Up @@ -222,7 +228,7 @@ def segment(self, sdata: SpatialData) -> SpatialData:
2 * self.cfg.segmentation.diameter,
)

self.shapes_layer_name=self.cfg.segmentation.output_shapes_layer
self.shapes_layer_name = self.cfg.segmentation.output_shapes_layer

# Perform segmentation
sdata = nas.im.segment(
Expand Down
4 changes: 2 additions & 2 deletions src/napari_sparrow/plot/_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np

from napari_sparrow.plot._plot import plot_shapes
from napari_sparrow.image._image import _get_image_boundary
from napari_sparrow.image._image import _get_boundary


def score_genes(
Expand Down Expand Up @@ -62,7 +62,7 @@ def score_genes(
si = sdata.images[img_layer]

if crd is None:
crd = _get_image_boundary(si)
crd = _get_boundary(si)

# Custom colormap:
colors = np.concatenate(
Expand Down
4 changes: 2 additions & 2 deletions src/napari_sparrow/plot/_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from napari_sparrow.image._image import (
_apply_transform,
_get_image_boundary,
_get_boundary,
_unapply_transform,
)
from napari_sparrow.shape import intersect_rectangles
Expand Down Expand Up @@ -281,7 +281,7 @@ def _plot_shapes( # FIXME: rename, this does not always plot a shapes layer any
# Update coords
si, x_coords_orig, y_coords_orig = _apply_transform(si)

image_boundary = _get_image_boundary(si)
image_boundary = _get_boundary(si)

if crd is not None:
_crd = crd
Expand Down
Loading

0 comments on commit a6b3689

Please sign in to comment.