From e6a887f1721b89d51328214b108e0da5401a24a9 Mon Sep 17 00:00:00 2001 From: Matt Einhorn Date: Mon, 22 Apr 2024 06:44:02 -0400 Subject: [PATCH] Allow turning off classification or detection in GUI (#402) * Allow turning off classification or detection in GUI. * Fix test. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Refactor to fix code analysis errors. * Ensure array is always 2d. * Apply suggestions from code review Co-authored-by: Igor Tatarnikov <61896994+IgorTatarnikov@users.noreply.github.com> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Igor Tatarnikov <61896994+IgorTatarnikov@users.noreply.github.com> --- cellfinder/core/main.py | 92 +++--- cellfinder/napari/detect/detect.py | 264 +++++++++++++----- cellfinder/napari/detect/detect_containers.py | 10 +- cellfinder/napari/detect/thread_worker.py | 14 + cellfinder/napari/utils.py | 100 +++++-- tests/napari/test_utils.py | 59 +++- 6 files changed, 396 insertions(+), 143 deletions(-) diff --git a/cellfinder/core/main.py b/cellfinder/core/main.py index cb78cca4..c74a9d44 100644 --- a/cellfinder/core/main.py +++ b/cellfinder/core/main.py @@ -7,6 +7,7 @@ from typing import Callable, List, Optional, Tuple import numpy as np +from brainglobe_utils.cells.cells import Cell from brainglobe_utils.general.logging import suppress_specific_logs from cellfinder.core import logger @@ -42,6 +43,9 @@ def main( cube_height: int = 50, cube_depth: int = 20, network_depth: depth_type = "50", + skip_detection: bool = False, + skip_classification: bool = False, + detected_cells: List[Cell] = None, *, detect_callback: Optional[Callable[[int], None]] = None, classify_callback: Optional[Callable[[int], None]] = None, @@ -65,52 +69,58 @@ def main( from cellfinder.core.detect import detect from cellfinder.core.tools import prep - logger.info("Detecting cell candidates") + if not skip_detection: + logger.info("Detecting cell candidates") - points = detect.main( - signal_array, - start_plane, - end_plane, - voxel_sizes, - soma_diameter, - max_cluster_size, - ball_xy_size, - ball_z_size, - ball_overlap_fraction, - soma_spread_factor, - n_free_cpus, - log_sigma_size, - n_sds_above_mean_thresh, - callback=detect_callback, - ) - - if detect_finished_callback is not None: - detect_finished_callback(points) - - install_path = None - model_weights = prep.prep_model_weights( - model_weights, install_path, model, n_free_cpus - ) - if len(points) > 0: - logger.info("Running classification") - points = classify.main( - points, + points = detect.main( signal_array, - background_array, - n_free_cpus, + start_plane, + end_plane, voxel_sizes, - network_voxel_sizes, - batch_size, - cube_height, - cube_width, - cube_depth, - trained_model, - model_weights, - network_depth, - callback=classify_callback, + soma_diameter, + max_cluster_size, + ball_xy_size, + ball_z_size, + ball_overlap_fraction, + soma_spread_factor, + n_free_cpus, + log_sigma_size, + n_sds_above_mean_thresh, + callback=detect_callback, ) + + if detect_finished_callback is not None: + detect_finished_callback(points) else: - logger.info("No candidates, skipping classification") + points = detected_cells or [] # if None + detect_finished_callback(points) + + if not skip_classification: + install_path = None + model_weights = prep.prep_model_weights( + model_weights, install_path, model, n_free_cpus + ) + if len(points) > 0: + logger.info("Running classification") + points = classify.main( + points, + signal_array, + background_array, + n_free_cpus, + voxel_sizes, + network_voxel_sizes, + batch_size, + cube_height, + cube_width, + cube_depth, + trained_model, + model_weights, + network_depth, + callback=classify_callback, + ) + else: + logger.info("No candidates, skipping classification") + return points diff --git a/cellfinder/napari/detect/detect.py b/cellfinder/napari/detect/detect.py index eaaab20f..cdf36939 100644 --- a/cellfinder/napari/detect/detect.py +++ b/cellfinder/napari/detect/detect.py @@ -1,8 +1,11 @@ +from functools import partial from math import ceil from pathlib import Path -from typing import Optional +from typing import Any, Callable, Dict, Optional, Tuple import napari +import napari.layers +from brainglobe_utils.cells.cells import Cell from magicgui import magicgui from magicgui.widgets import FunctionGui, ProgressBar from napari.utils.notifications import show_info @@ -10,9 +13,11 @@ from cellfinder.core.classify.cube_generator import get_cube_depth_min_max from cellfinder.napari.utils import ( - add_layers, + add_classified_layers, + add_single_layer, cellfinder_header, html_label_widget, + napari_array_to_cells, ) from .detect_containers import ( @@ -32,16 +37,10 @@ MIN_PLANES_ANALYSE = 0 -def detect_widget() -> FunctionGui: - """ - Create a detection plugin GUI. - """ - progress_bar = ProgressBar() - - # options that is filled in from the gui - options = {"signal_image": None, "background_image": None, "viewer": None} - - # signal and background images are separated out from the main magicgui +def get_heavy_widgets( + options: Dict[str, Any] +) -> Tuple[Callable, Callable, Callable]: + # signal and other input are separated out from the main magicgui # parameter selections and are inserted as widget children in their own # sub-containers of the root. Because if these image parameters are # included in the root widget, every time *any* parameter updates, the gui @@ -91,6 +90,140 @@ def background_image_opt( """ options["background_image"] = background_image + @magicgui( + call_button=False, + persist=False, + scrollable=False, + labels=False, + auto_call=True, + ) + def cell_layer_opt( + cell_layer: napari.layers.Points, + ): + """ + magicgui widget for setting the cell layer input when detection is + skipped. + + Parameters + ---------- + cell_layer : napari.layers.Points + If detection is skipped, select the cell layer containing the + detected cells to use for classification + """ + options["cell_layer"] = cell_layer + + return signal_image_opt, background_image_opt, cell_layer_opt + + +def add_heavy_widgets( + root: FunctionGui, + widgets: Tuple[FunctionGui, ...], + new_names: Tuple[str, ...], + insertions: Tuple[str, ...], +) -> None: + for widget, new_name, insertion in zip(widgets, new_names, insertions): + # make it look as if it's directly in the root container + widget.margins = 0, 0, 0, 0 + # the parameters of these widgets are updated using `auto_call` only. + # If False, magicgui passes these as args to root() when the root's + # function runs. But that doesn't list them as args of its function + widget.gui_only = True + root.insert(root.index(insertion) + 1, widget) + getattr(root, widget.name).label = new_name + + +def restore_options_defaults(widget: FunctionGui) -> None: + """ + Restore default widget values. + """ + defaults = { + **DataInputs.defaults(), + **DetectionInputs.defaults(), + **ClassificationInputs.defaults(), + **MiscInputs.defaults(), + } + for name, value in defaults.items(): + if value is not None: # ignore fields with no default + getattr(widget, name).value = value + + +def get_results_callback( + skip_classification: bool, viewer: napari.Viewer +) -> Callable: + """ + Returns the callback that is connected to output of the pipeline. + It returns the detected points that we have to visualize. + """ + if skip_classification: + # after detection w/o classification, everything is unknown + def done_func(points): + add_single_layer( + points, + viewer=viewer, + name="Cell candidates", + cell_type=Cell.UNKNOWN, + ) + + else: + # after classification we have either cell or unknown + def done_func(points): + add_classified_layers( + points, + viewer=viewer, + unknown_name="Rejected", + cell_name="Detected", + ) + + return done_func + + +def find_local_planes( + viewer: napari.Viewer, + voxel_size_z: float, + signal_image: napari.layers.Image, +) -> Tuple[int, int]: + """ + When detecting only locally, it returns the start and end planes to use. + """ + current_plane = viewer.dims.current_step[0] + + # so a reasonable number of cells in the plane are detected + planes_needed = MIN_PLANES_ANALYSE + int( + ceil((CUBE_DEPTH * NETWORK_VOXEL_SIZES[0]) / voxel_size_z) + ) + + start_plane, end_plane = get_cube_depth_min_max( + current_plane, planes_needed + ) + start_plane = max(0, start_plane) + end_plane = min(len(signal_image.data), end_plane) + + return start_plane, end_plane + + +def reraise(e: Exception) -> None: + """Re-raises the exception.""" + raise Exception from e + + +def detect_widget() -> FunctionGui: + """ + Create a detection plugin GUI. + """ + progress_bar = ProgressBar() + + # options that is filled in from the gui + options = { + "signal_image": None, + "background_image": None, + "viewer": None, + "cell_layer": None, + } + + signal_image_opt, background_image_opt, cell_layer_opt = get_heavy_widgets( + options + ) + @magicgui( detection_label=html_label_widget("Cell detection", tag="h3"), **DataInputs.widget_representation(), @@ -109,6 +242,7 @@ def widget( voxel_size_y: float, voxel_size_x: float, detection_options, + skip_detection: bool, soma_diameter: float, ball_xy_size: float, ball_z_size: float, @@ -118,6 +252,7 @@ def widget( soma_spread_factor: float, max_cluster_size: int, classification_options, + skip_classification: bool, trained_model: Optional[Path], use_pre_trained_weights: bool, misc_options, @@ -139,6 +274,10 @@ def widget( Size of your voxels in the y direction (top to bottom) voxel_size_x : float Size of your voxels in the x direction (left to right) + skip_detection : bool + If selected, the detection step is skipped and instead we get the + detected cells from the cell layer below (from a previous + detection run or import) soma_diameter : float The expected in-plane soma diameter (microns) ball_xy_size : float @@ -159,6 +298,9 @@ def widget( should be attempted use_pre_trained_weights : bool Select to use pre-trained model weights + skip_classification : bool + If selected, the classification step is skipped and all cells from + the detection stage are added trained_model : Optional[Path] Trained model file path (home directory (default) -> pretrained weights) @@ -184,24 +326,39 @@ def widget( # cellfinder plugin is fully open and initialized signal_image_opt() background_image_opt() + cell_layer_opt() signal_image = options["signal_image"] - background_image = options["background_image"] - viewer = options["viewer"] - if signal_image is None or background_image is None: + if signal_image is None or options["background_image"] is None: show_info("Both signal and background images must be specified.") return + detected_cells = [] + if skip_detection: + if options["cell_layer"] is None: + show_info( + "Skip detection selected, but no existing cell layer " + "is selected." + ) + return + + # set cells as unknown so that classification will process them + detected_cells = napari_array_to_cells( + options["cell_layer"], Cell.UNKNOWN + ) + data_inputs = DataInputs( signal_image.data, - background_image.data, + options["background_image"].data, voxel_size_z, voxel_size_y, voxel_size_x, ) detection_inputs = DetectionInputs( + skip_detection, + detected_cells, soma_diameter, ball_xy_size, ball_z_size, @@ -215,24 +372,15 @@ def widget( if use_pre_trained_weights: trained_model = None classification_inputs = ClassificationInputs( - use_pre_trained_weights, trained_model + skip_classification, use_pre_trained_weights, trained_model ) - end_plane = len(signal_image.data) if end_plane == 0 else end_plane - if analyse_local: - current_plane = viewer.dims.current_step[0] - - # so a reasonable number of cells in the plane are detected - planes_needed = MIN_PLANES_ANALYSE + int( - ceil((CUBE_DEPTH * NETWORK_VOXEL_SIZES[0]) / voxel_size_z) - ) - - start_plane, end_plane = get_cube_depth_min_max( - current_plane, planes_needed + start_plane, end_plane = find_local_planes( + options["viewer"], voxel_size_z, signal_image ) - start_plane = max(0, start_plane) - end_plane = min(len(signal_image.data), end_plane) + elif not end_plane: + end_plane = len(signal_image.data) misc_inputs = MiscInputs( start_plane, end_plane, n_free_cpus, analyse_local, debug @@ -244,58 +392,34 @@ def widget( classification_inputs, misc_inputs, ) + worker.returned.connect( - lambda points: add_layers(points, viewer=viewer) + get_results_callback(skip_classification, options["viewer"]) ) - # Make sure if the worker emits an error, it is propagated to this # thread - def reraise(e): - raise Exception from e - worker.errored.connect(reraise) + worker.connect_progress_bar_callback(progress_bar) - def update_progress_bar(label: str, max: int, value: int): - progress_bar.label = label - progress_bar.max = max - progress_bar.value = value - - worker.update_progress_bar.connect(update_progress_bar) worker.start() widget.native.layout().insertWidget(0, cellfinder_header()) - @widget.reset_button.changed.connect - def restore_defaults(): - """ - Restore default widget values. - """ - defaults = { - **DataInputs.defaults(), - **DetectionInputs.defaults(), - **ClassificationInputs.defaults(), - **MiscInputs.defaults(), - } - for name, value in defaults.items(): - if value is not None: # ignore fields with no default - getattr(widget, name).value = value + # reset restores defaults + widget.reset_button.changed.connect( + partial(restore_options_defaults, widget) + ) # Insert progress bar before the run and reset buttons - widget.insert(-3, progress_bar) - - # add the signal and background image parameters - # make it look as if it's directly in the root container - signal_image_opt.margins = 0, 0, 0, 0 - # the parameters are updated using `auto_call` only. If False, magicgui - # passes these as args to widget(), which doesn't list them as args - signal_image_opt.gui_only = True - widget.insert(3, signal_image_opt) - widget.signal_image_opt.label = "Signal image" - - background_image_opt.margins = 0, 0, 0, 0 - background_image_opt.gui_only = True - widget.insert(4, background_image_opt) - widget.background_image_opt.label = "Background image" + widget.insert(widget.index("debug") + 1, progress_bar) + + # add the signal and background image etc. + add_heavy_widgets( + widget, + (background_image_opt, signal_image_opt, cell_layer_opt), + ("Background image", "Signal image", "Candidate cell layer"), + ("voxel_size_z", "voxel_size_z", "soma_diameter"), + ) scroll = QScrollArea() scroll.setWidget(widget._widget._qwidget) diff --git a/cellfinder/napari/detect/detect_containers.py b/cellfinder/napari/detect/detect_containers.py index 824a2a0b..39fda163 100644 --- a/cellfinder/napari/detect/detect_containers.py +++ b/cellfinder/napari/detect/detect_containers.py @@ -1,8 +1,9 @@ from dataclasses import dataclass from pathlib import Path -from typing import Optional +from typing import List, Optional import numpy +from brainglobe_utils.cells.cells import Cell from cellfinder.napari.input_container import InputContainer from cellfinder.napari.utils import html_label_widget @@ -59,6 +60,8 @@ def widget_representation(cls) -> dict: class DetectionInputs(InputContainer): """Container for cell candidate detection inputs.""" + skip_detection: bool = False + detected_cells: Optional[List[Cell]] = None soma_diameter: float = 16.0 ball_xy_size: float = 6 ball_z_size: float = 15 @@ -75,6 +78,7 @@ def as_core_arguments(self) -> dict: def widget_representation(cls) -> dict: return dict( detection_options=html_label_widget("Detection:"), + skip_detection=dict(value=cls.defaults()["skip_detection"]), soma_diameter=cls._custom_widget("soma_diameter"), ball_xy_size=cls._custom_widget( "ball_xy_size", custom_label="Ball filter (xy)" @@ -107,6 +111,7 @@ def widget_representation(cls) -> dict: class ClassificationInputs(InputContainer): """Container for classification inputs.""" + skip_classification: bool = False use_pre_trained_weights: bool = True trained_model: Optional[Path] = Path.home() @@ -123,6 +128,9 @@ def widget_representation(cls) -> dict: value=cls.defaults()["use_pre_trained_weights"] ), trained_model=dict(value=cls.defaults()["trained_model"]), + skip_classification=dict( + value=cls.defaults()["skip_classification"] + ), ) diff --git a/cellfinder/napari/detect/thread_worker.py b/cellfinder/napari/detect/thread_worker.py index ea44dded..c4392860 100644 --- a/cellfinder/napari/detect/thread_worker.py +++ b/cellfinder/napari/detect/thread_worker.py @@ -1,3 +1,4 @@ +from magicgui.widgets import ProgressBar from napari.qt.threading import WorkerBase, WorkerBaseSignals from qtpy.QtCore import Signal @@ -41,6 +42,19 @@ def __init__( self.classification_inputs = classification_inputs self.misc_inputs = misc_inputs + def connect_progress_bar_callback(self, progress_bar: ProgressBar): + """ + Connects the progress bar to the work so that updates are shown on + the bar. + """ + + def update_progress_bar(label: str, max: int, value: int): + progress_bar.label = label + progress_bar.max = max + progress_bar.value = value + + self.update_progress_bar.connect(update_progress_bar) + def work(self) -> list: self.update_progress_bar.emit("Setting up detection...", 1, 0) diff --git a/cellfinder/napari/utils.py b/cellfinder/napari/utils.py index d48b81fb..2f853fe0 100644 --- a/cellfinder/napari/utils.py +++ b/cellfinder/napari/utils.py @@ -1,8 +1,8 @@ from typing import List, Tuple import napari +import napari.layers import numpy as np -import pandas as pd from brainglobe_utils.cells.cells import Cell from brainglobe_utils.qtpy.logo import header_widget @@ -31,16 +31,28 @@ def cellfinder_header(): ) -def add_layers(points: List[Cell], viewer: napari.Viewer) -> None: +# the xyz axis order in napari relative to ours. I.e. our zeroth axis is the +# napari last axis. Ours is XYZ. +napari_points_axis_order = 2, 1, 0 +# the xyz axis order in brainglobe relative to napari. I.e. napari's zeroth +# axis is our last axis - it's just flipped +brainglobe_points_axis_order = napari_points_axis_order + + +def add_classified_layers( + points: List[Cell], + viewer: napari.Viewer, + unknown_name: str = "Rejected", + cell_name: str = "Detected", +) -> None: """ - Adds classified cell candidates as two separate point layers to the napari - viewer. + Adds cell candidates as two separate point layers - unknowns and cells, to + the napari viewer. Does not add any other cell types, only Cell.UNKNOWN + and Cell.CELL from the list of cells. """ - detected, rejected = cells_to_array(points) - viewer.add_points( - rejected, - name="Rejected", + cells_to_array(points, Cell.UNKNOWN, napari_order=True), + name=unknown_name, size=15, n_dimensional=True, opacity=0.6, @@ -50,8 +62,8 @@ def add_layers(points: List[Cell], viewer: napari.Viewer) -> None: metadata=dict(point_type=Cell.UNKNOWN), ) viewer.add_points( - detected, - name="Detected", + cells_to_array(points, Cell.CELL, napari_order=True), + name=cell_name, size=15, n_dimensional=True, opacity=0.6, @@ -61,23 +73,61 @@ def add_layers(points: List[Cell], viewer: napari.Viewer) -> None: ) -def cells_df_as_np( - cells_df: pd.DataFrame, - new_order: List[int] = [2, 1, 0], - type_column: str = "type", +def add_single_layer( + points: List[Cell], + viewer: napari.Viewer, + name: str, + cell_type: int, +) -> None: + """ + Adds all cells of cell_type Cell.TYPE to a new point layer in the napari + viewer, with given name. + """ + viewer.add_points( + cells_to_array(points, cell_type, napari_order=True), + name=name, + size=15, + n_dimensional=True, + opacity=0.6, + symbol="ring", + face_color="lightskyblue", + visible=True, + metadata=dict(point_type=cell_type), + ) + + +def cells_to_array( + cells: List[Cell], cell_type: int, napari_order: bool = True ) -> np.ndarray: """ - Convert a dataframe to an array, dropping *type_column* and re-ordering - the columns with *new_order*. + Converts all the cells of the given type as a 2D pos array. + The column order is either XYZ, otherwise it's the napari ordering + of the 3 axes (napari_points_axis_order). """ - cells_df = cells_df.drop(columns=[type_column]) - cells = cells_df[cells_df.columns[new_order]] - cells = cells.to_numpy() - return cells + cells = [c for c in cells if c.type == cell_type] + if not cells: + # make sure we return 2d array if cells is empty + return np.zeros((0, 3), dtype=np.int_) + points = np.array([(c.x, c.y, c.z) for c in cells]) + + if napari_order: + return points[:, napari_points_axis_order] + return points -def cells_to_array(cells: List[Cell]) -> Tuple[np.ndarray, np.ndarray]: - df = pd.DataFrame([c.to_dict() for c in cells]) - points = cells_df_as_np(df[df["type"] == Cell.CELL]) - rejected = cells_df_as_np(df[df["type"] == Cell.UNKNOWN]) - return points, rejected +def napari_array_to_cells( + points: napari.layers.Points, + cell_type: int, + brainglobe_order: Tuple[int, int, int] = brainglobe_points_axis_order, +) -> List[Cell]: + """ + Takes a napari Points layer and returns a list of cell objects, one for + each point in the layer. + """ + data = np.asarray(points.data)[:, brainglobe_order].tolist() + + cells = [] + for row in data: + cells.append(Cell(pos=row, cell_type=cell_type)) + + return cells diff --git a/tests/napari/test_utils.py b/tests/napari/test_utils.py index f8eff803..6cf864f3 100644 --- a/tests/napari/test_utils.py +++ b/tests/napari/test_utils.py @@ -1,22 +1,69 @@ +import numpy as np from brainglobe_utils.cells.cells import Cell from cellfinder.napari.utils import ( - add_layers, + add_classified_layers, + cells_to_array, html_label_widget, + napari_array_to_cells, + napari_points_axis_order, ) -def test_add_layers(make_napari_viewer): - """Smoke test for add_layers utility""" +def test_add_classified_layers(make_napari_viewer): + """Smoke test for add_classified_layers utility""" + cell_pos = [1, 2, 3] + unknown_pos = [4, 5, 6] points = [ - Cell(pos=[1, 2, 3], cell_type=Cell.CELL), - Cell(pos=[4, 5, 6], cell_type=Cell.UNKNOWN), + Cell(pos=cell_pos, cell_type=Cell.CELL), + Cell(pos=unknown_pos, cell_type=Cell.UNKNOWN), ] viewer = make_napari_viewer() n_layers = len(viewer.layers) - add_layers(points, viewer) # adds a "detected" and a "rejected layer" + # adds a "detected" and a "rejected layer" + add_classified_layers( + points, viewer, unknown_name="rejected", cell_name="accepted" + ) assert len(viewer.layers) == n_layers + 2 + # check names match + rej_layer = cell_layer = None + for layer in reversed(viewer.layers): + if layer.name == "accepted" and cell_layer is None: + cell_layer = layer + if layer.name == "rejected" and rej_layer is None: + rej_layer = layer + assert cell_layer is not None + assert rej_layer is not None + assert cell_layer.data is not None + assert rej_layer.data is not None + + # check data added in correct column order + # CELL types + cell_data = np.array([cell_pos]) + assert np.all( + cells_to_array(points, Cell.CELL, napari_order=False) == cell_data + ) + # convert to napari order and check it is in napari + cell_data = cell_data[:, napari_points_axis_order] + assert np.all(cell_layer.data == cell_data) + + # UNKNOWN type + rej_data = np.array([unknown_pos]) + assert np.all( + cells_to_array(points, Cell.UNKNOWN, napari_order=False) == rej_data + ) + # convert to napari order and check it is in napari + rej_data = rej_data[:, napari_points_axis_order] + assert np.all(rej_layer.data == rej_data) + + # get cells back from napari points + cells_again = napari_array_to_cells(cell_layer.data, cell_type=Cell.CELL) + cells_again.extend( + napari_array_to_cells(rej_layer.data, cell_type=Cell.UNKNOWN) + ) + assert cells_again == points + def test_html_label_widget(): """Simple unit test for the HTML Label widget"""