Skip to content

Commit

Permalink
Merge pull request #28 from HelmholtzAI-Consultants-Munich:multi-anno…
Browse files Browse the repository at this point in the history
…tation

Multi annotation
  • Loading branch information
francesco-campi authored Dec 3, 2024
2 parents 7f7518a + 2b661cb commit 4df8bf1
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 16 deletions.
8 changes: 5 additions & 3 deletions napari_organoid_counter/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,22 @@ def write_to_json(name, data):
with open(name, 'w') as outfile:
json.dump(data, outfile)

def get_bboxes_as_dict(bboxes, bbox_ids, scores, scales):
""" Write all data, boxes, ids and scores and scale, to a dict so we can later save as a json """
def get_bboxes_as_dict(bboxes, bbox_ids, scores, scales, labels):
""" Write all data, boxes, ids and scores, scale and class label, to a dict so we can later save as a json """
data_json = {}
for idx, bbox in enumerate(bboxes):
x1, y1 = bbox[0]
x2, y2 = bbox[2]

data_json.update({str(bbox_ids[idx]): {'box_id': str(bbox_ids[idx]),
'x1': str(x1),
'x2': str(x2),
'y1': str(y1),
'y2': str(y2),
'confidence': str(scores[idx]),
'scale_x': str(scales[0]),
'scale_y': str(scales[1])
'scale_y': str(scales[1]),
'class': labels[idx]
}
})
return data_json
Expand Down
176 changes: 163 additions & 13 deletions napari_organoid_counter/_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
from skimage.io import imsave
from datetime import datetime

import napari

from napari import layers
from napari.utils.notifications import show_info
from napari.utils.notifications import show_info, show_error, show_warning

import numpy as np

from qtpy.QtCore import Qt
from qtpy.QtWidgets import QWidget, QVBoxLayout, QApplication, QDialog, QFileDialog, QGroupBox, QHBoxLayout, QLabel, QComboBox, QPushButton, QLineEdit, QProgressBar, QSlider
Expand Down Expand Up @@ -95,6 +99,10 @@ def __init__(self,
self.stored_confidences = {}
self.stored_diameters = {}

# Initialize multi_annotation_mode to False by default
self.multi_annotation_mode = False
# self.single_annotation_mode = True # Initially, it's single annotation mode

# setup gui
self.setLayout(QVBoxLayout())
self.layout().addWidget(self._setup_input_widget())
Expand All @@ -112,9 +120,76 @@ def __init__(self,
self.viewer.layers.events.inserted.connect(self._added_layer)
self.viewer.layers.events.removed.connect(self._removed_layer)
self.viewer.layers.selection.events.changed.connect(self._sel_layer_changed)

# setup flags used for changing slider and text of min diameter and confidence threshold
self.diameter_slider_changed = False
self.confidence_slider_changed = False
self.confidence_slider_changed = False

# Key binding to change the edge_color of the bounding boxes to green
@self.viewer.bind_key('g')
def change_edge_color_to_green(viewer: napari.Viewer):
if not self.multi_annotation_mode: # Check if single-annotation mode is active
show_error("Cannot change edge color. Change to multi-annotation mode to enable this feature.")
return
if self.cur_shapes_layer is not None: # Ensure shapes layer exists
selected_shapes = self.cur_shapes_layer.selected_data # Retrieves indices of shapes currently selected, returns a set
if len(selected_shapes) > 0:
# Modify the edge color only for the selected shapes
current_edge_colors = self.cur_shapes_layer.edge_color
for idx in selected_shapes:
# Save original color
# if idx not in self.original_colors:
# self.original_colors[idx] = current_edge_colors[idx].copy()
# Update to the new color
current_edge_colors[idx] = settings.COLOR_CLASS_1
self.cur_shapes_layer.edge_color = current_edge_colors # Apply the changes
show_info(f"Changed edge color of shapes {list(selected_shapes)} to green.")
else:
show_warning("No shapes selected to change edge color.")

# Key binding to change the edge_color of the bounding boxes to blue
@self.viewer.bind_key('h')
def change_edge_color_to_blue(viewer: napari.Viewer):
if not self.multi_annotation_mode: # Check if single-annotation mode is active
show_error("Cannot change edge color. Change to multi-annotation mode to enable this feature.")
return
if self.cur_shapes_layer is not None: # Ensure shapes layer exists
selected_shapes = self.cur_shapes_layer.selected_data
if len(selected_shapes) > 0:
# Modify the edge color only for the selected shapes
current_edge_colors = self.cur_shapes_layer.edge_color
for idx in selected_shapes:
# Save original color
# if idx not in self.original_colors:
# self.original_colors[idx] = current_edge_colors[idx].copy()
# Update to the new color
current_edge_colors[idx] = settings.COLOR_CLASS_2
self.cur_shapes_layer.edge_color = current_edge_colors # Apply the changes
show_info(f"Changed edge color of {list(selected_shapes)} to blue.")
else:
show_warning("No shapes selected to change edge color.")

# Key binding to reset the edge_color of selected bounding boxes to the original magenta color
@self.viewer.bind_key('m')
def change_to_original_color(viewer: napari.Viewer):
if not self.multi_annotation_mode: # Check if single-annotation mode is active
show_info("Cannot change edge color. Change to multi-annotation mode to enable this feature.")
return
if self.cur_shapes_layer is not None: # Ensure shapes layer exists
selected_shapes = self.cur_shapes_layer.selected_data
if len(selected_shapes) > 0:
current_edge_colors = self.cur_shapes_layer.edge_color
# Modify the edge color only for the selected shapes
current_edge_colors = self.cur_shapes_layer.edge_color
for idx in selected_shapes:
# if idx in self.original_colors:
# Revert to the original color
current_edge_colors[idx] = settings.COLOR_DEFAULT
self.cur_shapes_layer.edge_color = current_edge_colors # Apply the changes
show_info(f"Reset edge color of {list(selected_shapes)} to magenta.")
else:
show_warning("No shapes selected to reset edge color.")


def handle_progress(self, blocknum, blocksize, totalsize):
""" When the model is being downloaded, this method is called and th progress of the download
Expand Down Expand Up @@ -212,12 +287,13 @@ def _update_vis_bboxes(self, bboxes, scores, box_ids, labels_layer_name):
face_color='transparent',
properties = properties,
text = text_params,
edge_color='magenta',
edge_color=settings.COLOR_DEFAULT,
shape_type='rectangle',
edge_width=12) # warning generated here

# set current_edge_width so edge width is the same when users annotate - doesnt' fix new preds being added!
self.viewer.layers[labels_layer_name].current_edge_width = 12


def _on_preprocess_click(self):
""" Is called whenever preprocess button is clicked """
Expand Down Expand Up @@ -404,6 +480,17 @@ def _on_screenshot_click(self):
name,_ = fd.getSaveFileName(self, 'Save File', potential_name, 'Image files (*.png);;(*.tiff)') #, 'CSV Files (*.csv)')
if name: imsave(name, screenshot)

def on_annotation_mode_changed(self, index):
"""Callback for dropdown selection."""
if index == 0: # Single Annotation
self.multi_annotation_mode = False
# self.single_annotation_mode = True
show_info("Switched to Single Annotation mode.")
elif index == 1: # Multi Annotation
self.multi_annotation_mode = True
# self.single_annotation_mode = False
show_info("Switched to Multi Annotation mode.")

def _on_save_csv_click(self):
""" Is called whenever Save features button is clicked """
bboxes = self.viewer.layers[self.save_layer_name].data
Expand All @@ -421,16 +508,57 @@ def _on_save_json_click(self):
""" Is called whenever Save boxes button is clicked """
bboxes = self.viewer.layers[self.save_layer_name].data
#scores = #add
if not bboxes: show_info('No organoids detected! Please run auto organoid counter or run algorithm first and try again!')
if not bboxes:
show_info('No organoids detected! Please run auto organoid counter or run algorithm first and try again!')
return

# Check for multi-annotation mode
if self.multi_annotation_mode:

# Get the edge colors for all bounding boxes
edge_colors = self.cur_shapes_layer.edge_color
labels = []

# Check if all bounding boxes have their edge color set (not green or blue)
green = np.array(settings.COLOR_CLASS_1)
blue = np.array(settings.COLOR_CLASS_2)

all_colored = True
for edge_color in edge_colors:
# Compare the colors with a tolerance using np.allclose to account for floating-point errors
if not (np.allclose(edge_color[:3], green[:3]) or np.allclose(edge_color[:3], blue[:3])):
all_colored = False
break

if not all_colored:
show_error('Please change the color of all bounding boxes before saving.')
return

# Assign organoid label based on edge_color
for edge_color in edge_colors:
if np.allclose(edge_color[:3], green[:3]):
labels.append(0) # Label for green
elif np.allclose(edge_color[:3], blue[:3]):
labels.append(1) # Label for blue
else:
raise ValueError(f"Unexpected edge color {edge_color[:3]} encountered.")

#elif self.single_annotation_mode:
else:
data_json = utils.get_bboxes_as_dict(bboxes,
self.viewer.layers[self.save_layer_name].properties['box_id'],
self.viewer.layers[self.save_layer_name].properties['scores'],
self.viewer.layers[self.save_layer_name].scale)
# write bbox coordinates to json
fd = QFileDialog()
name,_ = fd.getSaveFileName(self, 'Save File', self.save_layer_name, 'JSON files (*.json)')#, 'CSV Files (*.csv)')
if name: utils.write_to_json(name, data_json)
# Single annotation mode: all bounding boxes get a default label
labels = [0] * len(bboxes) # Default label for single annotation mode

data_json = utils.get_bboxes_as_dict(bboxes,
self.viewer.layers[self.save_layer_name].properties['box_id'],
self.viewer.layers[self.save_layer_name].properties['scores'],
self.viewer.layers[self.save_layer_name].scale,
labels=labels)


# write bbox coordinates to json
fd = QFileDialog()
name,_ = fd.getSaveFileName(self, 'Save File', self.save_layer_name, 'JSON files (*.json)')#, 'CSV Files (*.csv)')
if name: utils.write_to_json(name, data_json)

def _update_added_image(self, added_items):
"""
Expand Down Expand Up @@ -471,7 +599,8 @@ def _update_added_shapes(self, added_items):
self.organoiDL.update_bboxes_scores(self.cur_shapes_name,
self.cur_shapes_layer.data,
self.cur_shapes_layer.properties['scores'],
self.cur_shapes_layer.properties['box_id'])
self.cur_shapes_layer.properties['box_id']
)
self.cur_shapes_layer.events.data.connect(self.shapes_event_handler)

def _update_remove_shapes(self, removed_layers):
Expand Down Expand Up @@ -507,6 +636,7 @@ def shapes_event_handler(self, event):
new_ids[-1] = self.organoiDL.next_id[self.cur_shapes_name]
new_scores = self.viewer.layers[self.cur_shapes_name].properties['scores']
new_scores[-1] = 1

# set new properties to shapes layer
self.viewer.layers[self.cur_shapes_name].properties ={'box_id': new_ids,'scores': new_scores}
# refresh text displayed
Expand All @@ -530,6 +660,7 @@ def _setup_input_widget(self):
window_sizes_box = self._setup_window_sizes_box()
downsampling_box = self._setup_downsampling_box()
run_box = self._setup_run_box()
annotation_mode_box = self._setup_annotation_mode_box() # Annotation mode dropdown to select single or multi-annotation
self._setup_progress_box()

# and add all these to the layout
Expand All @@ -540,6 +671,7 @@ def _setup_input_widget(self):
vbox.addLayout(window_sizes_box)
vbox.addLayout(downsampling_box)
vbox.addLayout(run_box)
vbox.addLayout(annotation_mode_box) # Add the annotation dropdown
vbox.addWidget(self.progress_box)
input_widget.setLayout(vbox)
return input_widget
Expand Down Expand Up @@ -682,6 +814,24 @@ def _setup_run_box(self):
hbox.addWidget(run_btn)
hbox.addStretch(1)
return hbox

def _setup_annotation_mode_box(self):
"""
Sets up the GUI part where the annotation mode is selected.
"""
hbox = QHBoxLayout()

# Label
annotation_mode_label = QLabel("Annotation Mode:", self)
hbox.addWidget(annotation_mode_label)

# Dropdown
self.annotation_mode_dropdown = QComboBox()
self.annotation_mode_dropdown.addItems(["Single Annotation", "Multi Annotation"])
self.annotation_mode_dropdown.currentIndexChanged.connect(self.on_annotation_mode_changed)
hbox.addWidget(self.annotation_mode_dropdown)

return hbox

def _setup_progress_box(self):
"""
Expand Down
11 changes: 11 additions & 0 deletions napari_organoid_counter/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,18 @@ def init():
"rtmdet": {"source": "https://zenodo.org/records/11388549/files/rtmdet_l_organoid.py",
"destination": ".mim/configs/rtmdet/rtmdet_l_organoid.py"
}

}

# Add color definitions
global COLOR_CLASS_1
COLOR_CLASS_1 = [85 / 255, 1.0, 0, 1.0] # Green

global COLOR_CLASS_2
COLOR_CLASS_2 = [0, 29 / 255, 1.0, 1.0] # Blue

global COLOR_DEFAULT
COLOR_DEFAULT = [1., 0, 1., 1.] # Magenta



0 comments on commit 4df8bf1

Please sign in to comment.