Skip to content

Commit

Permalink
Merge pull request #484 from bnmajor/bug-fixes
Browse files Browse the repository at this point in the history
BUG: Incorrectly resolved conflicts created series of bugs
  • Loading branch information
thewtex authored Jul 29, 2022
2 parents 82bf597 + 5bf71ba commit faa5410
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 108 deletions.
8 changes: 8 additions & 0 deletions itkwidgets/_initialization_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,11 @@ def init_params_dict(itk_viewer):
'y_slice': itk_viewer.setYSlice,
'z_slice': itk_viewer.setZSlice,
}

def init_key_aliases():
return {
'data': 'image',
'image': 'image',
'label_image': 'labelImage',
'point_sets': 'pointSets',
}
110 changes: 14 additions & 96 deletions itkwidgets/integrations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,133 +9,51 @@
from .xarray import HAVE_XARRAY, xarray_data_array_to_numpy, xarray_data_set_to_numpy
from ..render_types import RenderType

_image_count = 1

async def _set_viewer_image(itk_viewer, image, name=None, is_label=False):
global _image_count
if isinstance(image, itkwasm.Image):
if not name:
name = image.name
if not name:
name = f"image {_image_count}"
_image_count += 1
if is_label:
await itk_viewer.setLabelImage(image)
else:
await itk_viewer.setImage(image, name)
elif isinstance(image, np.ndarray):
if not name:
name = f"image {_image_count}"
_image_count += 1
if is_label:
await itk_viewer.setLabelImage(image)
else:
await itk_viewer.setImage(image, name)
elif isinstance(image, zarr.Group):
if not name:
name = f"image {_image_count}"
_image_count += 1
if is_label:
await itk_viewer.setLabelImage(image)
else:
await itk_viewer.setImage(image, name)
elif HAVE_ITK:
async def _get_viewer_image(image):
if HAVE_ITK:
import itk
if isinstance(image, itk.Image):
wasm_image = itk_image_to_wasm_image(image)
name = image.GetObjectName()
if not name:
name = f"image {_image_count}"
_image_count += 1
if is_label:
await itk_viewer.setLabelImage(wasm_image)
else:
await itk_viewer.setImage(wasm_image, name)
return itk_image_to_wasm_image(image)
if HAVE_VTK:
import vtk
if isinstance(image, vtk.vtkImageData):
ndarray = vtk_image_to_ndarray(image)
if not name:
name = f"image {_image_count}"
_image_count += 1
if is_label:
await itk_viewer.setLabelImage(ndarray)
else:
await itk_viewer.setImage(ndarray, name)
return vtk_image_to_ndarray(image)
if HAVE_DASK:
import dask
if isinstance(image, dask.array.core.Array):
ndarray = dask_array_to_ndarray(image)
name = image.name
if not name:
name = f"image {_image_count}"
_image_count += 1
if is_label:
await itk_viewer.setLabelImage(ndarray)
else:
await itk_viewer.setImage(ndarray, name)
return dask_array_to_ndarray(image)
if HAVE_TORCH:
import torch
if isinstance(image, torch.Tensor):
if not name:
name = f"image {_image_count}"
_image_count += 1
if is_label:
await itk_viewer.setLabelImage(image.numpy())
else:
await itk_viewer.setImage(image.numpy(), name)
return image.numpy()
if HAVE_XARRAY:
import xarray
if isinstance(image, xarray.DataArray):
ndarray = xarray_data_array_to_numpy(image)
name = image.name
if not name:
name = f"image {_image_count}"
_image_count += 1
if is_label:
await itk_viewer.setLabelImage(ndarray)
else:
await itk_viewer.setImage(ndarray, name)
return xarray_data_array_to_numpy(image)
if isinstance(image, xarray.Dataset):
ndarray = xarray_data_set_to_numpy(image)
if not name:
name = f"image {_image_count}"
_image_count += 1
if is_label:
await itk_viewer.setLabelImage(ndarray)
else:
await itk_viewer.setImage(ndarray, name)
return xarray_data_set_to_numpy(image)


async def _set_viewer_point_sets(itk_viewer, point_sets):
if isinstance(point_sets, itkwasm.PointSet):
await itk_viewer.setPointSets(point_sets)
elif isinstance(point_sets, np.ndarray):
await itk_viewer.setPointSets(point_sets)
elif isinstance(point_sets, zarr.Group):
await itk_viewer.setPointSets(point_sets)
async def _get_viewer_point_sets(itk_viewer, point_sets):
if HAVE_VTK:
import vtk
if isinstance(point_sets, vtk.vtkPolyData):
vtkjs_polydata = vtk_polydata_to_vtkjs(point_sets)
await itk_viewer.setPointSets(vtkjs_polydata)
return vtk_polydata_to_vtkjs(point_sets)
if HAVE_DASK:
import dask
if isinstance(point_sets, dask.array.core.Array):
ndarray = dask_array_to_ndarray(point_sets)
await itk_viewer.setPointSets(ndarray)
return dask_array_to_ndarray(point_sets)
if HAVE_TORCH:
import torch
if isinstance(point_sets, torch.Tensor):
await itk_viewer.setPointSets(point_sets.numpy())
return point_sets.numpy()
if HAVE_XARRAY:
import xarray
if isinstance(point_sets, xarray.DataArray):
ndarray = xarray_data_array_to_numpy(point_sets)
await itk_viewer.setPointSets(ndarray)
return xarray_data_array_to_numpy(point_sets)
if isinstance(point_sets, xarray.Dataset):
ndarray = xarray_data_set_to_numpy(point_sets)
await itk_viewer.setPointSets(ndarray)
return xarray_data_set_to_numpy(point_sets)


def _detect_render_type(data, input_type) -> RenderType:
Expand Down
35 changes: 23 additions & 12 deletions itkwidgets/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from typing import List

from ._type_aliases import Gaussians, Style, Image, Point_Sets
from ._initialization_params import init_params_dict
from .integrations import _detect_render_type, _set_viewer_image, _set_viewer_point_sets
from ._initialization_params import init_params_dict, init_key_aliases
from .integrations import _detect_render_type, _get_viewer_image, _get_viewer_point_sets
from .render_types import RenderType

__all__ = [
Expand All @@ -23,9 +23,11 @@ def __init__(
"""Create a viewer."""
self._init_viewer_kwargs = dict(ui_collapsed=ui_collapsed, rotate=rotate, ui=ui)
self._init_viewer_kwargs.update(**add_data_kwargs)
self.init_data = {}

def _get_input_data(self):
input_options = ["data", "image", "point_sets"]
inputs = []
for option in input_options:
data = self._init_viewer_kwargs.get(option, None)
if data is not None:
Expand Down Expand Up @@ -58,27 +60,32 @@ async def run(self, ctx):
else:
config = {}

data, input_type = self._get_input_data()
inputs = self._get_input_data()

init_data = None
if data is not None:
self.init_data.clear()
for (input_type, data) in inputs:
render_type = _detect_render_type(data, input_type)
key = init_key_aliases()[input_type]
if render_type is RenderType.IMAGE:
init_data = {"image": data}
result = await _get_viewer_image(data)
elif render_type is RenderType.POINT_SET:
init_data = {"pointSets": data}
result = await _get_viewer_point_sets(data)
if not result:
result = data
self.init_data[key] = result

itk_viewer = await api.createWindow(
name=f"itkwidgets viewer {_viewer_count}",
type="itk-vtk-viewer",
src="https://kitware.github.io/itk-vtk-viewer/app",
fullscreen=False,
data=init_data,
data=self.init_data,
# config should be a python data dictionary and can't be a string e.g. 'pydata-sphinx',
config=config,
)
_viewer_count += 1

self.set_default_ui_values(itk_viewer)
self.itk_viewer = itk_viewer

def set_default_ui_values(self, itk_viewer):
Expand Down Expand Up @@ -112,9 +119,11 @@ def set_background_color(self, bgColor: List[float]):
async def set_image(self, image: Image):
render_type = _detect_render_type(image, 'image')
if render_type is RenderType.IMAGE:
await _set_viewer_image(self.viewer_rpc.itk_viewer, image)
image = _get_viewer_image(image)
await self.viewer_rpc.itk_viewer.setImage(image)
elif render_type is RenderType.POINT_SET:
await _set_viewer_point_sets(self.viewer_rpc.itk_viewer, image)
image = _get_viewer_point_sets(image)
await self.viewer_rpc.itk_viewer.setPointSets(image)

def set_image_blend_mode(self, mode: str):
self.viewer_rpc.itk_viewer.setImageBlendMode(mode)
Expand Down Expand Up @@ -152,9 +161,11 @@ def set_image_volume_sample_distance(self, distance: float):
async def set_label_image(self, label_image: Image):
render_type = _detect_render_type(label_image, 'image')
if render_type is RenderType.IMAGE:
await _set_viewer_image(self.viewer_rpc.itk_viewer, label_image, is_label=True)
label_image = _get_viewer_image(label_image, is_label=True)
await self.viewer_rpc.itk_viewer.setImage(label_image)
elif render_type is RenderType.POINT_SET:
await _set_viewer_point_sets(self.viewer_rpc.itk_viewer, label_image)
label_image = _get_viewer_point_sets(label_image, is_label=True)
await self.viewer_rpc.itk_viewer.setPointSets(label_image)

def set_label_image_blend(self, blend: float):
self.viewer_rpc.itk_viewer.setLabelImageBlend(blend)
Expand Down

0 comments on commit faa5410

Please sign in to comment.