From 66c41ed8bf4fc5ade988fe612212f5a491022c14 Mon Sep 17 00:00:00 2001 From: healthonrails Date: Mon, 11 Nov 2024 09:40:03 -0500 Subject: [PATCH] Refactor: Move FlexibleWorker Class to Dedicated Workers Module --- annolid/gui/app.py | 147 +++++++++++----------- annolid/gui/label_file.py | 4 +- annolid/gui/widgets/caption.py | 1 + annolid/gui/workers.py | 67 ++++++++++ annolid/segmentation/cutie_vos/predict.py | 4 +- 5 files changed, 148 insertions(+), 75 deletions(-) create mode 100644 annolid/gui/workers.py diff --git a/annolid/gui/app.py b/annolid/gui/app.py index 8dc858d..48636cf 100644 --- a/annolid/gui/app.py +++ b/annolid/gui/app.py @@ -42,6 +42,7 @@ from labelme.utils import newAction from labelme.app import MainWindow from annolid.gui.shape import Shape +from annolid.gui.workers import FlexibleWorker import subprocess import requests from PIL import ImageQt @@ -77,39 +78,6 @@ LABEL_COLORMAP = imgviz.label_colormap(value=200) -class FlexibleWorker(QtCore.QObject): - start = QtCore.Signal() - finished = QtCore.Signal(object) - return_value = QtCore.Signal(object) - stop_signal = QtCore.Signal() - progress_changed = QtCore.Signal(int) - - def __init__(self, function, *args, **kwargs): - super(FlexibleWorker, self).__init__() - - self.function = function - self.args = args - self.kwargs = kwargs - self.stopped = False - - self.stop_signal.connect(self.stop) - - def run(self): - self.stopped = False - result = self.function(*self.args, **self.kwargs) - self.return_value.emit(result) - self.finished.emit(result) - - def stop(self): - self.stopped = True - - def is_stopped(self): - return self.stopped - - def progress_callback(self, progress): - self.progress_changed.emit(progress) - - class LoadFrameThread(QtCore.QObject): """ Thread for loading video frames. @@ -1180,14 +1148,14 @@ def predict_from_next_frame(self, or self.automatic_pause_enabled) if "sam2_hiera" in model_name: self.pred_worker = FlexibleWorker( - function=process_video, + task_function=process_video, video_path=self.video_file, frame_idx=self.frame_number, model_config='sam2_hiera_l.yaml' if 'hiera_l' in model_name else "sam2_hiera_s.yaml", ) else: self.pred_worker = FlexibleWorker( - function=self.video_processor.process_video_frames, + task_function=self.video_processor.process_video_frames, start_frame=self.frame_number+1, end_frame=end_frame, step=self.step_size, @@ -1206,13 +1174,13 @@ def predict_from_next_frame(self, "background-color: red; color: white;") self.stop_prediction_flag = True self.pred_worker.moveToThread(self.seg_pred_thread) - self.pred_worker.start.connect(self.pred_worker.run) - self.pred_worker.return_value.connect( + self.pred_worker.start_signal.connect(self.pred_worker.run) + self.pred_worker.result_signal.connect( self.lost_tracking_instance) - self.pred_worker.finished.connect(self.predict_is_ready) + self.pred_worker.finished_signal.connect(self.predict_is_ready) self.seg_pred_thread.finished.connect( self.seg_pred_thread.quit) - self.pred_worker.start.emit() + self.pred_worker.start_signal.emit() def lost_tracking_instance(self, message): if message is None: @@ -1604,41 +1572,78 @@ def frames(self): self.importDirImages(out_frames_dir) def convert_json_to_tracked_csv(self): - if self.video_file is not None: - video_file = self.video_file - out_folder = Path(video_file).with_suffix('') - if out_folder is None or not out_folder.exists(): - QtWidgets.QMessageBox.about(self, - "No predictions", - "Help Annolid achieve precise predictions by labeling a frame.\ - Your input is valuable!") + """ + Convert JSON annotations to a tracked CSV file and handle the progress using a separate thread. + """ + if not self.video_file: + QtWidgets.QMessageBox.warning( + self, "Missing Video File", "No video file selected.") + return + video_file = self.video_file + out_folder = Path(video_file).with_suffix('') + + if not out_folder or not out_folder.exists(): + QtWidgets.QMessageBox.warning( + self, + "No Predictions Found", + "Help Annolid achieve precise predictions by labeling a frame. Your input is valuable!" + ) return - def update_progress(progress): - self.progress_bar.setValue(progress) + self._initialize_progress_bar() + + try: + self.worker = FlexibleWorker( + task_function=labelme2csv.convert_json_to_csv, + json_folder=str(out_folder), + progress_callback=self._update_progress_bar + ) + self.thread = QtCore.QThread() + + # Move the worker to the thread and connect signals + self.worker.moveToThread(self.thread) + self._connect_worker_signals() + + # Safely start the thread and worker signal + self.thread.start() + # Emit in a thread-safe way + QtCore.QTimer.singleShot( + 0, lambda: self.worker.start_signal.emit()) + except Exception as e: + QtWidgets.QMessageBox.critical( + self, "Error", f"An unexpected error occurred: {str(e)}") + finally: + self.statusBar().removeWidget(self.progress_bar) + + def _initialize_progress_bar(self): + """Initialize the progress bar and add it to the status bar.""" + self.progress_bar.setValue(0) self.statusBar().addWidget(self.progress_bar) - self.worker = FlexibleWorker( - labelme2csv.convert_json_to_csv, str(out_folder), - progress_callback=update_progress) - self.thread = QtCore.QThread() - self.worker.moveToThread(self.thread) - self.worker.start.connect(self.worker.run) - self.worker.finished.connect(self.place_preference_analyze_auto) - self.worker.finished.connect(self.thread.quit) - self.worker.finished.connect(self.worker.deleteLater) - self.thread.finished.connect(self.thread.deleteLater) - self.worker.finished.connect(lambda: - QtWidgets.QMessageBox.about(self, - "Tracking results are ready.", - f"Kindly review the file here: {str(out_folder) + '.csv'}.")) - self.worker.progress_changed.connect(update_progress) - - self.thread.start() - self.worker.start.emit() - self.statusBar().removeWidget(self.progress_bar) + def _update_progress_bar(self, progress): + """Update the progress bar's value.""" + self.progress_bar.setValue(progress) + + def _connect_worker_signals(self): + """Connect worker signals to their respective slots safely.""" + self.worker.start_signal.connect(self.worker.run) + self.worker.finished_signal.connect(self.place_preference_analyze_auto) + + # Ensure cleanup happens in the right thread + self.worker.finished_signal.connect(self.thread.quit) + self.worker.finished_signal.connect(lambda: self.worker.deleteLater()) + self.thread.finished.connect(lambda: self.thread.deleteLater()) + + self.worker.finished_signal.connect( + lambda: QtWidgets.QMessageBox.information( + self, + "Tracking Complete", + f"Kindly review the file here: {Path(self.video_file).with_suffix('.csv')}" + ) + ) + self.worker.progress_signal.connect(self._update_progress_bar) def tracks(self): """ @@ -1821,10 +1826,10 @@ def models(self): process = start_tensorboard(log_dir=out_runs_dir) try: self.seg_train_thread.start() - train_worker = FlexibleWorker(function=segmentor.train) + train_worker = FlexibleWorker(task_function=segmentor.train) train_worker.moveToThread(self.seg_train_thread) - train_worker.start.connect(train_worker.run) - train_worker.start.emit() + train_worker.start_signal.connect(train_worker.run) + train_worker.start_signal.emit() except Exception: segmentor.train() diff --git a/annolid/gui/label_file.py b/annolid/gui/label_file.py index adef63b..06dd5ed 100644 --- a/annolid/gui/label_file.py +++ b/annolid/gui/label_file.py @@ -202,8 +202,8 @@ def save( caption=caption, ) for key, value in otherData.items(): - assert key not in data - data[key] = value + if key not in data: + data[key] = value try: with open(filename, "w") as f: json.dump(data, f, ensure_ascii=False, indent=2) diff --git a/annolid/gui/widgets/caption.py b/annolid/gui/widgets/caption.py index 1449a13..8756e7d 100644 --- a/annolid/gui/widgets/caption.py +++ b/annolid/gui/widgets/caption.py @@ -12,6 +12,7 @@ class CaptionWidget(QtWidgets.QWidget): charInserted = Signal(str) # Signal emitted when a character is inserted charDeleted = Signal(str) # Signal emitted when a character is deleted readCaptionFinished = Signal() # Define a custom signal + imageNotFound = Signal(str) def __init__(self, parent=None): super().__init__(parent) diff --git a/annolid/gui/workers.py b/annolid/gui/workers.py new file mode 100644 index 0000000..4257224 --- /dev/null +++ b/annolid/gui/workers.py @@ -0,0 +1,67 @@ +from qtpy import QtCore + + +class FlexibleWorker(QtCore.QObject): + """ + A flexible worker class that runs a given function in a separate thread. + Provides signals to indicate the start, progress, return value, and completion of the task. + """ + + start_signal = QtCore.Signal() + finished_signal = QtCore.Signal(object) + result_signal = QtCore.Signal(object) + stop_signal = QtCore.Signal() + progress_signal = QtCore.Signal(int) + + def __init__(self, task_function, *args, **kwargs): + """ + Initialize the FlexibleWorker with the function to run and its arguments. + + :param task_function: The function to be executed. + :param args: Positional arguments for the function. + :param kwargs: Keyword arguments for the function. + """ + super().__init__() + self._task_function = task_function + self._args = args + self._kwargs = kwargs + self._is_stopped = False + + # Connect the stop signal to the stop method + self.stop_signal.connect(self._stop) + + def run(self): + """ + Executes the task function with the provided arguments. + Emits signals for result and completion when done. + """ + self._is_stopped = False + try: + result = self._task_function(*self._args, **self._kwargs) + self.result_signal.emit(result) + self.finished_signal.emit(result) + except Exception as e: + # Optionally handle exceptions and emit an error signal if needed + self.finished_signal.emit(e) + + def _stop(self): + """ + Stops the worker by setting the stop flag. + """ + self._is_stopped = True + + def is_stopped(self): + """ + Check if the worker has been stopped. + + :return: True if the worker is stopped, otherwise False. + """ + return self._is_stopped + + def report_progress(self, progress): + """ + Reports the progress of the task. + + :param progress: An integer representing the progress percentage. + """ + self.progress_signal.emit(progress) diff --git a/annolid/segmentation/cutie_vos/predict.py b/annolid/segmentation/cutie_vos/predict.py index f7682f1..dffc71c 100644 --- a/annolid/segmentation/cutie_vos/predict.py +++ b/annolid/segmentation/cutie_vos/predict.py @@ -139,7 +139,7 @@ def _initialize_model(self): logger.info(f"Tmax: max_mem_frames: {self.max_mem_frames}") cutie_model = CUTIE(cfg).to(self.device).eval() model_weights = torch.load( - cfg.weights, map_location=self.device) + cfg.weights, map_location=self.device, weights_only=True) cutie_model.load_weights(model_weights) return cutie_model, cfg @@ -271,7 +271,7 @@ def commit_masks_into_permanent_memory(self, frame_number, labels_dict): dict: Updated labels dictionary. """ with torch.inference_mode(): - with torch.amp.autocast('cuda',enabled=self.cfg.amp and self.device == 'cuda'): + with torch.amp.autocast('cuda', enabled=self.cfg.amp and self.device == 'cuda'): png_file_paths = glob.glob( f"{self.video_folder}/{self.video_folder.name}_0*.png") png_file_paths = [p for p in png_file_paths if 'mask' not in p]