[Enh] Handle segmenting image layers that have non-1 layer.scale #804

Merged
6 changes: 5 additions & 1 deletion micro_sam/sam_annotator/_annotator.py
@@ -163,11 +163,15 @@ def _update_image(self, segmentation_result=None):

# Reset all layers.
self._viewer.layers["current_object"].data = np.zeros(self._shape, dtype="uint32")
self._viewer.layers["current_object"].scale = state.image_scale
self._viewer.layers["auto_segmentation"].data = np.zeros(self._shape, dtype="uint32")
self._viewer.layers["auto_segmentation"].scale = state.image_scale
if segmentation_result is None or segmentation_result is False:
self._viewer.layers["committed_objects"].data = np.zeros(self._shape, dtype="uint32")
else:
assert segmentation_result.shape == self._shape
self._viewer.layers["committed_objects"].data = segmentation_result

self._viewer.layers["committed_objects"].scale = state.image_scale
self._viewer.layers["point_prompts"].scale = state.image_scale
self._viewer.layers["prompts"].scale = state.image_scale
vutil.clear_annotations(self._viewer, clear_segmentations=False)
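As background for this hunk: a minimal napari sketch (not part of the PR; the data and layer names are made up) showing why annotation layers only overlay a scaled image correctly when they share its scale tuple.

```python
import numpy as np
import napari

viewer = napari.Viewer()

# Anisotropic volume: voxels are 4x larger along z than in-plane,
# so the image layer gets a non-1 scale.
image = np.random.rand(32, 256, 256)
image_scale = (4.0, 1.0, 1.0)
viewer.add_image(image, name="raw", scale=image_scale)

# Annotation layers need the same scale; otherwise they are drawn at a
# different size in world space and no longer overlay the raw data.
labels = np.zeros(image.shape, dtype="uint32")
viewer.add_labels(labels, name="committed_objects", scale=image_scale)

napari.run()
```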
2 changes: 2 additions & 0 deletions micro_sam/sam_annotator/_state.py
@@ -42,6 +42,7 @@ class AnnotatorState(metaclass=Singleton):
image_embeddings: Optional[util.ImageEmbeddings] = None
predictor: Optional[SamPredictor] = None
image_shape: Optional[Tuple[int, int]] = None
image_scale: Optional[Tuple[float, ...]] = None
embedding_path: Optional[str] = None
data_signature: Optional[str] = None

@@ -198,6 +199,7 @@ def reset_state(self):
self.image_embeddings = None
self.predictor = None
self.image_shape = None
self.image_scale = None
self.embedding_path = None
self.amg = None
self.amg_state = None
17 changes: 12 additions & 5 deletions micro_sam/sam_annotator/_widgets.py
@@ -325,7 +325,7 @@ def clear_volume(viewer: "napari.viewer.Viewer", all_slices: bool = True) -> Non
if all_slices:
vutil.clear_annotations(viewer)
else:
i = int(viewer.cursor.position[0])
i = int(viewer.dims.point[0])
vutil.clear_annotations_slice(viewer, i=i)


@@ -341,7 +341,7 @@ def clear_track(viewer: "napari.viewer.Viewer", all_frames: bool = True) -> None
_reset_tracking_state(viewer)
vutil.clear_annotations(viewer)
else:
i = int(viewer.cursor.position[0])
i = int(viewer.dims.point[0])
vutil.clear_annotations_slice(viewer, i=i)


@@ -736,7 +736,9 @@ def segment_slice(viewer: "napari.viewer.Viewer") -> None:
return None

shape = viewer.layers["current_object"].data.shape[1:]
position = viewer.cursor.position

position_world = viewer.dims.point
position = viewer.layers["point_prompts"].world_to_data(position_world)
z = int(position[0])

point_prompts = vutil.point_layer_to_prompts(viewer.layers["point_prompts"], z)
Contributor:
Wouldn't we also need to scale the points in here if we have a scale?

Author:
No, see my comment below. In point_layer_to_prompts everything is now in the image layer's data coordinates, because the annotation layers have the same scale as the image layer. We only had to transform viewer.dims.point because that is in world (scaled / canvas) coordinates.
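To illustrate the distinction described above, here is a minimal sketch of the world-to-data conversion using the public napari API (viewer.dims.point, Layer.world_to_data); the helper function and the assumption that a "point_prompts" layer exists are for illustration only.

```python
import numpy as np

def current_slice_index(viewer) -> int:
    """Return the slice under the dims slider in *data* coordinates.

    viewer.dims.point is in world coordinates. For a layer with scale
    (4.0, 1.0, 1.0), a slider position of z=20 in world space maps to
    z=5 in the layer's data; world_to_data undoes scale and translation
    before the value is used to index the array.
    """
    position_world = viewer.dims.point
    position = viewer.layers["point_prompts"].world_to_data(position_world)
    return int(np.round(position[0]))
```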

@@ -775,7 +777,7 @@ def segment_frame(viewer: "napari.viewer.Viewer") -> None:
return None
state = AnnotatorState()
shape = state.image_shape[1:]
position = viewer.cursor.position
position = viewer.dims.point
t = int(position[0])

point_prompts = vutil.point_layer_to_prompts(viewer.layers["point_prompts"], i=t, track_id=state.current_track_id)
@@ -868,7 +870,9 @@ def __init__(self, parent=None):
def _initialize_image(self):
state = AnnotatorState()
image_shape = self.image_selection.get_value().data.shape
image_scale = tuple(self.image_selection.get_value().scale)
state.image_shape = image_shape
state.image_scale = image_scale

def _create_image_section(self):
image_section = QtWidgets.QVBoxLayout()
@@ -1083,6 +1087,9 @@ def __call__(self, skip_validate=False):
ndim = image.data.ndim
state.image_shape = image.data.shape

# Set layer scale
state.image_scale = tuple(image.scale)

# Process tile_shape and halo, set other data.
tile_shape, halo = _process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y)
save_path = None if self.embeddings_save_path == "" else self.embeddings_save_path
@@ -1655,7 +1662,7 @@ def __call__(self):
if self.volumetric and self.apply_to_volume:
worker = self._run_segmentation_3d(kwargs)
elif self.volumetric and not self.apply_to_volume:
i = int(self._viewer.cursor.position[0])
i = int(self._viewer.dims.point[0])
worker = self._run_segmentation_2d(kwargs, i=i)
else:
worker = self._run_segmentation_2d(kwargs)
4 changes: 2 additions & 2 deletions micro_sam/sam_annotator/util.py
@@ -178,7 +178,7 @@ def point_layer_to_prompts(
this_points, this_labels = points, labels
else:
assert points.shape[1] == 3, f"{points.shape}"
mask = points[:, 0] == i
mask = np.round(points[:, 0]) == i
this_points = points[mask][:, 1:]
this_labels = labels[mask]
assert len(this_points) == len(this_labels)
@@ -355,7 +355,7 @@ def segment_slices_with_prompts(
image_shape = shape[1:]
seg = np.zeros(shape, dtype="uint32")

z_values = point_prompts.data[:, 0]
z_values = np.round(point_prompts.data[:, 0])
z_values_boxes = np.concatenate([box[:1, 0] for box in box_prompts.data]) if box_prompts.data else\
np.zeros(0, dtype="int")
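A small, self-contained example (values made up, not from the PR) of why exact equality against the slice index can fail once point coordinates pass through a scale transform, and what np.round changes:

```python
import numpy as np

# z-coordinates of point prompts after mapping world -> data coordinates;
# with a non-1 scale these are floats and need not be exact integers.
points = np.array([
    [4.9999999, 120.0, 130.0],
    [5.0000001,  80.0,  90.0],
    [7.0,        10.0,  20.0],
])
i = 5  # slice currently being segmented

mask_exact = points[:, 0] == i              # [False, False, False]: prompts lost
mask_rounded = np.round(points[:, 0]) == i  # [True, True, False]: prompts on slice 5 found
print(mask_exact, mask_rounded)
```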
