Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delete old training data saves and warn user #365

Merged
merged 3 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 54 additions & 17 deletions cellfinder/napari/curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
import tifffile
from brainglobe_napari_io.cellfinder.utils import convert_layer_to_cells
from brainglobe_utils.cells.cells import Cell
from brainglobe_utils.general.system import delete_directory_contents
from brainglobe_utils.IO.yaml import save_yaml
from brainglobe_utils.qtpy.dialog import display_warning
from brainglobe_utils.qtpy.interaction import add_button, add_combobox
from magicgui.widgets import ProgressBar
from napari.qt.threading import thread_worker
from napari.utils.notifications import show_info
Expand All @@ -20,8 +23,6 @@
QWidget,
)

from .utils import add_button, add_combobox, display_question

# Constants used throughout
WINDOW_HEIGHT = 750
WINDOW_WIDTH = 1500
Expand Down Expand Up @@ -173,33 +174,33 @@ def add_loading_panel(self, row: int, column: int = 0):
self.load_data_layout,
"Training_data (non_cells)",
self.point_layer_names,
4,
row=4,
callback=self.set_training_data_non_cell,
)
self.mark_as_cell_button = add_button(
"Mark as cell(s)",
self.load_data_layout,
self.mark_as_cell,
5,
row=5,
)
self.mark_as_non_cell_button = add_button(
"Mark as non cell(s)",
self.load_data_layout,
self.mark_as_non_cell,
5,
row=5,
column=1,
)
self.add_training_data_button = add_button(
"Add training data layers",
self.load_data_layout,
self.add_training_data,
6,
row=6,
)
self.save_training_data_button = add_button(
"Save training data",
self.load_data_layout,
self.save_training_data,
6,
row=6,
column=1,
)
self.load_data_layout.setColumnMinimumWidth(0, COLUMN_WIDTH)
Expand Down Expand Up @@ -256,7 +257,7 @@ def add_training_data(self):

overwrite = False
if self.training_data_cell_layer or self.training_data_non_cell_layer:
overwrite = display_question(
overwrite = display_warning(
self,
"Training data layers exist",
"Training data layers already exist, "
Expand Down Expand Up @@ -363,7 +364,10 @@ def mark_point_as_type(self, point_type: str):
)

def save_training_data(
self, *, block: bool = False, prompt_for_directory: bool = True
self,
*,
block: bool = False,
prompt_for_directory: bool = True,
) -> None:
"""
Parameters
Expand All @@ -373,16 +377,45 @@ def save_training_data(
prompt_for_directory :
If `True` show a file dialog for the user to select a directory.
"""

if self.is_data_extractable():
if prompt_for_directory:
self.get_output_directory()
# if the directory is not empty
if any(self.output_directory.iterdir()):
choice = display_warning(
self,
"About to save training data",
"Existing files will be will be deleted. Proceed?",
)
if not choice:
return
if self.output_directory is not None:
self.__prep_directories_for_save()
self.__extract_cubes(block=block)
self.__save_yaml_file()
show_info("Done")

self.update_status_label("Ready")

def __prep_directories_for_save(self):
self.yaml_filename = self.output_directory / "training.yml"
self.cell_cube_dir = self.output_directory / "cells"
self.no_cell_cube_dir = self.output_directory / "non_cells"

self.__delete_existing_saved_training_data()

def __delete_existing_saved_training_data(self):
self.yaml_filename.unlink(missing_ok=True)
for directory in (
self.cell_cube_dir,
self.no_cell_cube_dir,
):
if directory.exists():
delete_directory_contents(directory)
else:
directory.mkdir(exist_ok=True, parents=True)

def __extract_cubes(self, *, block=False):
"""
Parameters
Expand Down Expand Up @@ -489,18 +522,16 @@ def convert_layers_to_cells(self):
self.non_cells_to_extract = list(set(self.non_cells_to_extract))

def __save_yaml_file(self):
# TODO: implement this in a portable way
yaml_filename = self.output_directory / "training.yml"
yaml_section = [
{
"cube_dir": str(self.output_directory / "cells"),
"cube_dir": str(self.cell_cube_dir),
"cell_def": "",
"type": "cell",
"signal_channel": 0,
"bg_channel": 1,
},
{
"cube_dir": str(self.output_directory / "non_cells"),
"cube_dir": str(self.no_cell_cube_dir),
"cell_def": "",
"type": "no_cell",
"signal_channel": 0,
Expand All @@ -509,7 +540,7 @@ def __save_yaml_file(self):
]

yaml_contents = {"data": yaml_section}
save_yaml(yaml_contents, yaml_filename)
save_yaml(yaml_contents, self.yaml_filename)

def update_progress(self, attributes: dict):
"""
Expand Down Expand Up @@ -538,9 +569,15 @@ def extract_cubes(self):
"non_cells": self.non_cells_to_extract,
}

for cell_type, cell_list in to_extract.items():
cell_type_output_directory = self.output_directory / cell_type
cell_type_output_directory.mkdir(exist_ok=True, parents=True)
directories = {
"cells": self.cell_cube_dir,
"non_cells": self.no_cell_cube_dir,
}

for cell_type in ["cells", "non_cells"]:
cell_type_output_directory = directories[cell_type]
cell_list = to_extract[cell_type]

self.update_status_label(f"Saving {cell_type}...")

cube_generator = CubeGeneratorFromFile(
Expand Down
90 changes: 1 addition & 89 deletions cellfinder/napari/utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
from typing import Callable, List, Optional, Tuple
from typing import List, Tuple

import napari
import numpy as np
import pandas as pd
from brainglobe_utils.cells.cells import Cell
from pkg_resources import resource_filename
from qtpy.QtWidgets import (
QComboBox,
QLabel,
QLayout,
QMessageBox,
QPushButton,
QWidget,
)

brainglobe_logo = resource_filename(
"cellfinder", "napari/images/brainglobe.png"
Expand Down Expand Up @@ -98,83 +90,3 @@ def cells_to_array(cells: List[Cell]) -> Tuple[np.ndarray, np.ndarray]:
points = cells_df_as_np(df[df["type"] == Cell.CELL])
rejected = cells_df_as_np(df[df["type"] == Cell.UNKNOWN])
return points, rejected


def add_combobox(
layout: QLayout,
label: str,
items: List[str],
row: int,
column: int = 0,
label_stack: bool = False,
callback=None,
width: int = 150,
) -> Tuple[QComboBox, Optional[QLabel]]:
"""
Add a selection box to *layout*.
"""
if label_stack:
combobox_row = row + 1
combobox_column = column
else:
combobox_row = row
combobox_column = column + 1
combobox = QComboBox()
combobox.addItems(items)
if callback:
combobox.currentIndexChanged.connect(callback)
combobox.setMaximumWidth = width

if label is not None:
combobox_label = QLabel(label)
combobox_label.setMaximumWidth = width
layout.addWidget(combobox_label, row, column)
else:
combobox_label = None

layout.addWidget(combobox, combobox_row, combobox_column)
return combobox, combobox_label


def add_button(
label: str,
layout: QLayout,
connected_function: Callable,
row: int,
column: int = 0,
visibility: bool = True,
minimum_width: int = 0,
alignment: str = "center",
) -> QPushButton:
"""
Add a button to *layout*.
"""
button = QPushButton(label)
if alignment == "center":
pass
elif alignment == "left":
button.setStyleSheet("QPushButton { text-align: left; }")
elif alignment == "right":
button.setStyleSheet("QPushButton { text-align: right; }")

button.setVisible(visibility)
button.setMinimumWidth(minimum_width)
layout.addWidget(button, row, column)
button.clicked.connect(connected_function)
return button


def display_question(widget: QWidget, title: str, message: str) -> bool:
"""
Display a warning in a pop up that informs about overwriting files.
"""
message_reply = QMessageBox.question(
widget,
title,
message,
QMessageBox.Yes | QMessageBox.Cancel,
)
if message_reply == QMessageBox.Yes:
return True
else:
return False
39 changes: 0 additions & 39 deletions tests/napari/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import pytest
from brainglobe_utils.cells.cells import Cell
from qtpy.QtWidgets import QGridLayout

from cellfinder.napari.utils import (
add_button,
add_combobox,
add_layers,
html_label_widget,
)
Expand All @@ -27,38 +23,3 @@ def test_html_label_widget():
label_widget = html_label_widget("A nice label", tag="h1")
assert label_widget["widget_type"] == "Label"
assert label_widget["label"] == "<h1>A nice label</h1>"


@pytest.mark.parametrize("label_stack", [True, False])
@pytest.mark.parametrize("label", ["A label", None])
def test_add_combobox(label, label_stack):
"""
Smoke test for add_combobox for all conditional branches
"""
layout = QGridLayout()
combobox = add_combobox(
layout,
row=0,
label=label,
items=["item 1", "item 2"],
label_stack=label_stack,
)
assert combobox is not None


@pytest.mark.parametrize(
argnames="alignment", argvalues=["center", "left", "right"]
)
def test_add_button(alignment):
"""
Smoke tests for add_button for all conditional branches
"""
layout = QGridLayout()
button = add_button(
layout=layout,
connected_function=lambda: None,
label="A button",
row=0,
alignment=alignment,
)
assert button is not None