diff --git a/napari_organoid_counter/_utils.py b/napari_organoid_counter/_utils.py index 541fa84..be95840 100644 --- a/napari_organoid_counter/_utils.py +++ b/napari_organoid_counter/_utils.py @@ -62,12 +62,13 @@ def write_to_json(name, data): with open(name, 'w') as outfile: json.dump(data, outfile) -def get_bboxes_as_dict(bboxes, bbox_ids, scores, scales): - """ Write all data, boxes, ids and scores and scale, to a dict so we can later save as a json """ +def get_bboxes_as_dict(bboxes, bbox_ids, scores, scales, labels): + """ Write all data, boxes, ids and scores, scale and class label, to a dict so we can later save as a json """ data_json = {} for idx, bbox in enumerate(bboxes): x1, y1 = bbox[0] x2, y2 = bbox[2] + data_json.update({str(bbox_ids[idx]): {'box_id': str(bbox_ids[idx]), 'x1': str(x1), 'x2': str(x2), @@ -75,7 +76,8 @@ def get_bboxes_as_dict(bboxes, bbox_ids, scores, scales): 'y2': str(y2), 'confidence': str(scores[idx]), 'scale_x': str(scales[0]), - 'scale_y': str(scales[1]) + 'scale_y': str(scales[1]), + 'class': labels[idx] } }) return data_json diff --git a/napari_organoid_counter/_widget.py b/napari_organoid_counter/_widget.py index 37fe7d0..9d0b1c9 100644 --- a/napari_organoid_counter/_widget.py +++ b/napari_organoid_counter/_widget.py @@ -3,8 +3,12 @@ from skimage.io import imsave from datetime import datetime +import napari + from napari import layers -from napari.utils.notifications import show_info +from napari.utils.notifications import show_info, show_error, show_warning + +import numpy as np from qtpy.QtCore import Qt from qtpy.QtWidgets import QWidget, QVBoxLayout, QApplication, QDialog, QFileDialog, QGroupBox, QHBoxLayout, QLabel, QComboBox, QPushButton, QLineEdit, QProgressBar, QSlider @@ -95,6 +99,10 @@ def __init__(self, self.stored_confidences = {} self.stored_diameters = {} + # Initialize multi_annotation_mode to False by default + self.multi_annotation_mode = False + # self.single_annotation_mode = True # Initially, it's single annotation mode + # setup gui self.setLayout(QVBoxLayout()) self.layout().addWidget(self._setup_input_widget()) @@ -112,9 +120,76 @@ def __init__(self, self.viewer.layers.events.inserted.connect(self._added_layer) self.viewer.layers.events.removed.connect(self._removed_layer) self.viewer.layers.selection.events.changed.connect(self._sel_layer_changed) + # setup flags used for changing slider and text of min diameter and confidence threshold self.diameter_slider_changed = False - self.confidence_slider_changed = False + self.confidence_slider_changed = False + + # Key binding to change the edge_color of the bounding boxes to green + @self.viewer.bind_key('g') + def change_edge_color_to_green(viewer: napari.Viewer): + if not self.multi_annotation_mode: # Check if single-annotation mode is active + show_error("Cannot change edge color. Change to multi-annotation mode to enable this feature.") + return + if self.cur_shapes_layer is not None: # Ensure shapes layer exists + selected_shapes = self.cur_shapes_layer.selected_data # Retrieves indices of shapes currently selected, returns a set + if len(selected_shapes) > 0: + # Modify the edge color only for the selected shapes + current_edge_colors = self.cur_shapes_layer.edge_color + for idx in selected_shapes: + # Save original color + # if idx not in self.original_colors: + # self.original_colors[idx] = current_edge_colors[idx].copy() + # Update to the new color + current_edge_colors[idx] = settings.COLOR_CLASS_1 + self.cur_shapes_layer.edge_color = current_edge_colors # Apply the changes + show_info(f"Changed edge color of shapes {list(selected_shapes)} to green.") + else: + show_warning("No shapes selected to change edge color.") + + # Key binding to change the edge_color of the bounding boxes to blue + @self.viewer.bind_key('h') + def change_edge_color_to_blue(viewer: napari.Viewer): + if not self.multi_annotation_mode: # Check if single-annotation mode is active + show_error("Cannot change edge color. Change to multi-annotation mode to enable this feature.") + return + if self.cur_shapes_layer is not None: # Ensure shapes layer exists + selected_shapes = self.cur_shapes_layer.selected_data + if len(selected_shapes) > 0: + # Modify the edge color only for the selected shapes + current_edge_colors = self.cur_shapes_layer.edge_color + for idx in selected_shapes: + # Save original color + # if idx not in self.original_colors: + # self.original_colors[idx] = current_edge_colors[idx].copy() + # Update to the new color + current_edge_colors[idx] = settings.COLOR_CLASS_2 + self.cur_shapes_layer.edge_color = current_edge_colors # Apply the changes + show_info(f"Changed edge color of {list(selected_shapes)} to blue.") + else: + show_warning("No shapes selected to change edge color.") + + # Key binding to reset the edge_color of selected bounding boxes to the original magenta color + @self.viewer.bind_key('m') + def change_to_original_color(viewer: napari.Viewer): + if not self.multi_annotation_mode: # Check if single-annotation mode is active + show_info("Cannot change edge color. Change to multi-annotation mode to enable this feature.") + return + if self.cur_shapes_layer is not None: # Ensure shapes layer exists + selected_shapes = self.cur_shapes_layer.selected_data + if len(selected_shapes) > 0: + current_edge_colors = self.cur_shapes_layer.edge_color + # Modify the edge color only for the selected shapes + current_edge_colors = self.cur_shapes_layer.edge_color + for idx in selected_shapes: + # if idx in self.original_colors: + # Revert to the original color + current_edge_colors[idx] = settings.COLOR_DEFAULT + self.cur_shapes_layer.edge_color = current_edge_colors # Apply the changes + show_info(f"Reset edge color of {list(selected_shapes)} to magenta.") + else: + show_warning("No shapes selected to reset edge color.") + def handle_progress(self, blocknum, blocksize, totalsize): """ When the model is being downloaded, this method is called and th progress of the download @@ -212,12 +287,13 @@ def _update_vis_bboxes(self, bboxes, scores, box_ids, labels_layer_name): face_color='transparent', properties = properties, text = text_params, - edge_color='magenta', + edge_color=settings.COLOR_DEFAULT, shape_type='rectangle', edge_width=12) # warning generated here # set current_edge_width so edge width is the same when users annotate - doesnt' fix new preds being added! self.viewer.layers[labels_layer_name].current_edge_width = 12 + def _on_preprocess_click(self): """ Is called whenever preprocess button is clicked """ @@ -404,6 +480,17 @@ def _on_screenshot_click(self): name,_ = fd.getSaveFileName(self, 'Save File', potential_name, 'Image files (*.png);;(*.tiff)') #, 'CSV Files (*.csv)') if name: imsave(name, screenshot) + def on_annotation_mode_changed(self, index): + """Callback for dropdown selection.""" + if index == 0: # Single Annotation + self.multi_annotation_mode = False + # self.single_annotation_mode = True + show_info("Switched to Single Annotation mode.") + elif index == 1: # Multi Annotation + self.multi_annotation_mode = True + # self.single_annotation_mode = False + show_info("Switched to Multi Annotation mode.") + def _on_save_csv_click(self): """ Is called whenever Save features button is clicked """ bboxes = self.viewer.layers[self.save_layer_name].data @@ -421,16 +508,57 @@ def _on_save_json_click(self): """ Is called whenever Save boxes button is clicked """ bboxes = self.viewer.layers[self.save_layer_name].data #scores = #add - if not bboxes: show_info('No organoids detected! Please run auto organoid counter or run algorithm first and try again!') + if not bboxes: + show_info('No organoids detected! Please run auto organoid counter or run algorithm first and try again!') + return + + # Check for multi-annotation mode + if self.multi_annotation_mode: + + # Get the edge colors for all bounding boxes + edge_colors = self.cur_shapes_layer.edge_color + labels = [] + + # Check if all bounding boxes have their edge color set (not green or blue) + green = np.array(settings.COLOR_CLASS_1) + blue = np.array(settings.COLOR_CLASS_2) + + all_colored = True + for edge_color in edge_colors: + # Compare the colors with a tolerance using np.allclose to account for floating-point errors + if not (np.allclose(edge_color[:3], green[:3]) or np.allclose(edge_color[:3], blue[:3])): + all_colored = False + break + + if not all_colored: + show_error('Please change the color of all bounding boxes before saving.') + return + + # Assign organoid label based on edge_color + for edge_color in edge_colors: + if np.allclose(edge_color[:3], green[:3]): + labels.append(0) # Label for green + elif np.allclose(edge_color[:3], blue[:3]): + labels.append(1) # Label for blue + else: + raise ValueError(f"Unexpected edge color {edge_color[:3]} encountered.") + + #elif self.single_annotation_mode: else: - data_json = utils.get_bboxes_as_dict(bboxes, - self.viewer.layers[self.save_layer_name].properties['box_id'], - self.viewer.layers[self.save_layer_name].properties['scores'], - self.viewer.layers[self.save_layer_name].scale) - # write bbox coordinates to json - fd = QFileDialog() - name,_ = fd.getSaveFileName(self, 'Save File', self.save_layer_name, 'JSON files (*.json)')#, 'CSV Files (*.csv)') - if name: utils.write_to_json(name, data_json) + # Single annotation mode: all bounding boxes get a default label + labels = [0] * len(bboxes) # Default label for single annotation mode + + data_json = utils.get_bboxes_as_dict(bboxes, + self.viewer.layers[self.save_layer_name].properties['box_id'], + self.viewer.layers[self.save_layer_name].properties['scores'], + self.viewer.layers[self.save_layer_name].scale, + labels=labels) + + + # write bbox coordinates to json + fd = QFileDialog() + name,_ = fd.getSaveFileName(self, 'Save File', self.save_layer_name, 'JSON files (*.json)')#, 'CSV Files (*.csv)') + if name: utils.write_to_json(name, data_json) def _update_added_image(self, added_items): """ @@ -471,7 +599,8 @@ def _update_added_shapes(self, added_items): self.organoiDL.update_bboxes_scores(self.cur_shapes_name, self.cur_shapes_layer.data, self.cur_shapes_layer.properties['scores'], - self.cur_shapes_layer.properties['box_id']) + self.cur_shapes_layer.properties['box_id'] + ) self.cur_shapes_layer.events.data.connect(self.shapes_event_handler) def _update_remove_shapes(self, removed_layers): @@ -507,6 +636,7 @@ def shapes_event_handler(self, event): new_ids[-1] = self.organoiDL.next_id[self.cur_shapes_name] new_scores = self.viewer.layers[self.cur_shapes_name].properties['scores'] new_scores[-1] = 1 + # set new properties to shapes layer self.viewer.layers[self.cur_shapes_name].properties ={'box_id': new_ids,'scores': new_scores} # refresh text displayed @@ -530,6 +660,7 @@ def _setup_input_widget(self): window_sizes_box = self._setup_window_sizes_box() downsampling_box = self._setup_downsampling_box() run_box = self._setup_run_box() + annotation_mode_box = self._setup_annotation_mode_box() # Annotation mode dropdown to select single or multi-annotation self._setup_progress_box() # and add all these to the layout @@ -540,6 +671,7 @@ def _setup_input_widget(self): vbox.addLayout(window_sizes_box) vbox.addLayout(downsampling_box) vbox.addLayout(run_box) + vbox.addLayout(annotation_mode_box) # Add the annotation dropdown vbox.addWidget(self.progress_box) input_widget.setLayout(vbox) return input_widget @@ -682,6 +814,24 @@ def _setup_run_box(self): hbox.addWidget(run_btn) hbox.addStretch(1) return hbox + + def _setup_annotation_mode_box(self): + """ + Sets up the GUI part where the annotation mode is selected. + """ + hbox = QHBoxLayout() + + # Label + annotation_mode_label = QLabel("Annotation Mode:", self) + hbox.addWidget(annotation_mode_label) + + # Dropdown + self.annotation_mode_dropdown = QComboBox() + self.annotation_mode_dropdown.addItems(["Single Annotation", "Multi Annotation"]) + self.annotation_mode_dropdown.currentIndexChanged.connect(self.on_annotation_mode_changed) + hbox.addWidget(self.annotation_mode_dropdown) + + return hbox def _setup_progress_box(self): """ diff --git a/napari_organoid_counter/settings.py b/napari_organoid_counter/settings.py index b8b4039..2eef96f 100644 --- a/napari_organoid_counter/settings.py +++ b/napari_organoid_counter/settings.py @@ -38,7 +38,18 @@ def init(): "rtmdet": {"source": "https://zenodo.org/records/11388549/files/rtmdet_l_organoid.py", "destination": ".mim/configs/rtmdet/rtmdet_l_organoid.py" } + } + + # Add color definitions + global COLOR_CLASS_1 + COLOR_CLASS_1 = [85 / 255, 1.0, 0, 1.0] # Green + + global COLOR_CLASS_2 + COLOR_CLASS_2 = [0, 29 / 255, 1.0, 1.0] # Blue + + global COLOR_DEFAULT + COLOR_DEFAULT = [1., 0, 1., 1.] # Magenta