Skip to content

Commit

Permalink
Merge pull request #710 from bnmajor/roi-image
Browse files Browse the repository at this point in the history
Roi image
  • Loading branch information
bnmajor authored Dec 14, 2023
2 parents 2c5e5e4 + 552ba12 commit 31df05d
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 85 deletions.
208 changes: 130 additions & 78 deletions examples/integrations/itk/SelectROI.ipynb

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion itkwidgets/_initialization_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def parse_input_data(init_data_kwargs):
return inputs


def build_init_data(input_data):
def build_init_data(input_data, stores):
result= None
for input_type in DATA_OPTIONS:
data = input_data.pop(input_type, None)
Expand All @@ -83,12 +83,15 @@ def build_init_data(input_data):
if render_type is RenderType.IMAGE:
if input_type == 'label_image':
result = _get_viewer_image(data, label=True)
stores['LabelImage'] = result
render_type = RenderType.LABELIMAGE
elif input_type == 'fixed_image':
result = _get_viewer_image(data)
stores['Fixed'] = result
render_type = RenderType.FIXEDIMAGE
else:
result = _get_viewer_image(data, label=False)
stores['Image'] = result
elif render_type is RenderType.POINT_SET:
result = _get_viewer_point_set(data)
if result is None:
Expand Down
86 changes: 80 additions & 6 deletions itkwidgets/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from inspect import isawaitable
from typing import List, Union, Tuple
from IPython.display import display, HTML
from ngff_zarr import from_ngff_zarr, to_ngff_image, NgffImage
import uuid

from ._type_aliases import Gaussians, Style, Image, PointSet
Expand Down Expand Up @@ -136,9 +137,10 @@ def __init__(
self, ui_collapsed=True, rotate=False, ui="pydata-sphinx", **add_data_kwargs
):
"""Create a viewer."""
self.stores = {}
self.name = self.__str__()
input_data = parse_input_data(add_data_kwargs)
data = build_init_data(input_data)
data = build_init_data(input_data, self.stores)
if compare := input_data.get('compare'):
data['compare'] = compare
if ENVIRONMENT is not Env.HYPHA:
Expand Down Expand Up @@ -213,6 +215,7 @@ def set_image(self, image: Image, name: str = 'Image'):
render_type = _detect_render_type(image, 'image')
if render_type is RenderType.IMAGE:
image = _get_viewer_image(image, label=False)
self.stores[name] = image
if ENVIRONMENT is Env.HYPHA:
self.image = image
svc_name = f'{self.workspace}/itkwidgets-server:data-set'
Expand All @@ -225,8 +228,28 @@ def set_image(self, image: Image, name: str = 'Image'):
image = _get_viewer_point_set(image)
self.viewer_rpc.itk_viewer.setPointSets(image)
@fetch_value
async def get_image(self):
return await self.viewer_rpc.itk_viewer.getImage()
async def get_image(self, name: str = 'Image') -> NgffImage:
"""Get the full, highest resolution image.
:param name: Name of the loaded image data to use. 'Image', the
default, selects the first loaded image.
:type name: str
:return: image
:rtype: NgffImage
"""
if store := self.stores.get(name):
multiscales = from_ngff_zarr(store)
loaded_image = multiscales.images[0]
roi_data = loaded_image.data
return to_ngff_image(
roi_data,
dims=loaded_image.dims,
scale=loaded_image.scale,
name=name,
axes_units=loaded_image.axes_units
)
raise ValueError(f'No image data found for {name}.')

@fetch_value
def set_image_blend_mode(self, mode: str):
Expand Down Expand Up @@ -323,6 +346,36 @@ async def get_current_scale(self):
"""
return await self.viewer_rpc.itk_viewer.getLoadedScale()

@fetch_value
async def get_roi_image(self, scale: int = -1, name: str = 'Image') -> NgffImage:
"""Get the image for the current ROI.
:param scale: scale of the primary image to get the slices for the
current roi. -1, the default, uses the current scale.
:type scale: int
:param name: Name of the loaded image data to use. 'Image', the
default, selects the first loaded image.
:type name: str
:return: roi_image
:rtype: NgffImage
"""
roi_slices = await self.get_roi_slice(scale)
roi_region = await self.get_roi_region()
if store := self.stores.get(name):
multiscales = from_ngff_zarr(store)
loaded_image = multiscales.images[scale]
roi_data = loaded_image.data[roi_slices]
return to_ngff_image(
roi_data,
dims=loaded_image.dims,
scale=loaded_image.scale,
translation=roi_region[0],
name=name,
axes_units=loaded_image.axes_units
)
raise ValueError(f'No image data found for {name}.')

@fetch_value
async def get_roi_region(self):
"""Get the current region of interest in world / physical space.
Expand All @@ -339,7 +392,7 @@ async def get_roi_region(self):
return [{ 'x': x0, 'y': y0, 'z': z0 }, { 'x': x1, 'y': y1, 'z': z1 }]

@fetch_value
async def get_roi_slice(self, scale=-1):
async def get_roi_slice(self, scale: int = -1):
"""Get the current region of interest as Python slice objects for the
current resolution of the primary image. The result is in the order:
Expand Down Expand Up @@ -395,6 +448,7 @@ def set_label_image(self, label_image: Image):
render_type = _detect_render_type(label_image, 'image')
if render_type is RenderType.IMAGE:
label_image = _get_viewer_image(label_image, label=True)
self.stores['LabelImage'] = label_image
if ENVIRONMENT is Env.HYPHA:
self.label_image = label_image
svc_name = f"{self.workspace}/itkwidgets-server:data-set"
Expand All @@ -407,8 +461,24 @@ def set_label_image(self, label_image: Image):
label_image = _get_viewer_point_set(label_image)
self.viewer_rpc.itk_viewer.setPointSets(label_image)
@fetch_value
async def get_label_image(self):
return await self.viewer_rpc.itk_viewer.getLabelImage()
async def get_label_image(self) -> NgffImage:
"""Get the full, highest resolution label image.
:return: label_image
:rtype: NgffImage
"""
if store := self.stores.get('LabelImage'):
multiscales = from_ngff_zarr(store)
loaded_image = multiscales.images[0]
roi_data = loaded_image.data
return to_ngff_image(
roi_data,
dims=loaded_image.dims,
scale=loaded_image.scale,
name='LabelImage',
axes_units=loaded_image.axes_units
)
raise ValueError(f'No label image data found.')

@fetch_value
def set_label_image_blend(self, blend: float):
Expand Down Expand Up @@ -452,6 +522,10 @@ def set_layer_visibility(self, visible: bool, name: str):
async def get_layer_visibility(self, name: str):
return await self.viewer_rpc.itk_viewer.getLayerVisibility(name)

@fetch_value
def get_loaded_image_names(self):
return list(self.stores.keys())

@fetch_value
def add_point_set(self, pointSet: PointSet):
pointSet = _get_viewer_point_set(pointSet)
Expand Down

0 comments on commit 31df05d

Please sign in to comment.