Skip to content

Commit

Permalink
Allow turning off classification or detection in GUI (#402)
Browse files Browse the repository at this point in the history
* Allow turning off classification or detection in GUI.

* Fix test.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Refactor to fix code analysis errors.

* Ensure array is always 2d.

* Apply suggestions from code review

Co-authored-by: Igor Tatarnikov <[email protected]>

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Igor Tatarnikov <[email protected]>
  • Loading branch information
3 people authored Apr 22, 2024
1 parent 560909f commit e6a887f
Show file tree
Hide file tree
Showing 6 changed files with 396 additions and 143 deletions.
92 changes: 51 additions & 41 deletions cellfinder/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Callable, List, Optional, Tuple

import numpy as np
from brainglobe_utils.cells.cells import Cell
from brainglobe_utils.general.logging import suppress_specific_logs

from cellfinder.core import logger
Expand Down Expand Up @@ -42,6 +43,9 @@ def main(
cube_height: int = 50,
cube_depth: int = 20,
network_depth: depth_type = "50",
skip_detection: bool = False,
skip_classification: bool = False,
detected_cells: List[Cell] = None,
*,
detect_callback: Optional[Callable[[int], None]] = None,
classify_callback: Optional[Callable[[int], None]] = None,
Expand All @@ -65,52 +69,58 @@ def main(
from cellfinder.core.detect import detect
from cellfinder.core.tools import prep

logger.info("Detecting cell candidates")
if not skip_detection:
logger.info("Detecting cell candidates")

points = detect.main(
signal_array,
start_plane,
end_plane,
voxel_sizes,
soma_diameter,
max_cluster_size,
ball_xy_size,
ball_z_size,
ball_overlap_fraction,
soma_spread_factor,
n_free_cpus,
log_sigma_size,
n_sds_above_mean_thresh,
callback=detect_callback,
)

if detect_finished_callback is not None:
detect_finished_callback(points)

install_path = None
model_weights = prep.prep_model_weights(
model_weights, install_path, model, n_free_cpus
)
if len(points) > 0:
logger.info("Running classification")
points = classify.main(
points,
points = detect.main(
signal_array,
background_array,
n_free_cpus,
start_plane,
end_plane,
voxel_sizes,
network_voxel_sizes,
batch_size,
cube_height,
cube_width,
cube_depth,
trained_model,
model_weights,
network_depth,
callback=classify_callback,
soma_diameter,
max_cluster_size,
ball_xy_size,
ball_z_size,
ball_overlap_fraction,
soma_spread_factor,
n_free_cpus,
log_sigma_size,
n_sds_above_mean_thresh,
callback=detect_callback,
)

if detect_finished_callback is not None:
detect_finished_callback(points)
else:
logger.info("No candidates, skipping classification")
points = detected_cells or [] # if None
detect_finished_callback(points)

if not skip_classification:
install_path = None
model_weights = prep.prep_model_weights(
model_weights, install_path, model, n_free_cpus
)
if len(points) > 0:
logger.info("Running classification")
points = classify.main(
points,
signal_array,
background_array,
n_free_cpus,
voxel_sizes,
network_voxel_sizes,
batch_size,
cube_height,
cube_width,
cube_depth,
trained_model,
model_weights,
network_depth,
callback=classify_callback,
)
else:
logger.info("No candidates, skipping classification")

return points


Expand Down
Loading

0 comments on commit e6a887f

Please sign in to comment.