diff --git a/annolid/gui/app.py b/annolid/gui/app.py index 01f27d3..f59c988 100644 --- a/annolid/gui/app.py +++ b/annolid/gui/app.py @@ -1,75 +1,75 @@ +from labelme.ai import MODELS +from annolid.gui.widgets.caption import CaptionWidget +from annolid.data.videos import get_video_files +from annolid.gui.widgets.place_preference_dialog import TrackingAnalyzerDialog +from annolid.gui.widgets.advanced_parameters_dialog import AdvancedParametersDialog +from annolid.annotation import labelme2csv +from annolid.segmentation.SAM.edge_sam_bg import VideoProcessor +from annolid.annotation.timestamps import convert_frame_number_to_time +from annolid.postprocessing.quality_control import pred_dict_to_labelme +from annolid.gui.widgets.convert_labelme2csv_dialog import LabelmeJsonToCsvDialog +from annolid.gui.widgets.extract_keypoints_dialog import ExtractShapeKeyPointsDialog +from annolid.gui.widgets.convert_sleap_dialog import ConvertSleapDialog +from annolid.gui.widgets.downsample_videos_dialog import VideoRescaleWidget +from annolid.gui.widgets.step_size_widget import StepSizeWidget +from annolid.gui.widgets.video_slider import VideoSlider, VideoSliderMark +import qimage2ndarray +import atexit +import webbrowser +from annolid.gui.widgets import ProgressingWindow +from annolid.postprocessing.quality_control import TracksResults +from annolid.postprocessing.glitter import tracks2nix +from annolid.gui.widgets import SystemInfoDialog +from annolid.gui.widgets import TrackDialog +from annolid.gui.widgets import QualityControlDialog +from annolid.gui.widgets import Glitter2Dialog +from annolid.gui.widgets import TrainModelDialog +from annolid.gui.widgets import ConvertCOODialog +from annolid.gui.widgets import ExtractFrameDialog +from annolid.data import videos +from annolid.annotation import labelme2coco +from annolid.gui.widgets.text_prompt import AiRectangleWidget +from annolid.gui.widgets.canvas import Canvas +from annolid.configs import get_config +from annolid.gui.label_file 
import LabelFile +from annolid.gui.label_file import LabelFileError +from labelme.widgets import ToolBar +from annolid.utils.files import count_json_files +from annolid.utils.logger import logger +from labelme import utils +from labelme.widgets import LabelListWidgetItem +from labelme.widgets import BrightnessContrastDialog +from labelme.utils import newAction +from labelme.app import MainWindow +from annolid.gui.shape import Shape +import subprocess +import requests +from PIL import ImageQt +from labelme import QT5 +from labelme import PY2 +from qtpy import QtGui +from qtpy import QtWidgets +from qtpy.QtCore import Qt +from qtpy import QtCore +import functools +from pathlib import Path +import argparse +import imgviz +import codecs +import torch +from collections import deque +import numpy as np +import pandas as pd +import shutil +import html +import time +import os.path as osp +import csv +import re import sys import os # Enable CPU fallback for unsupported MPS ops os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" -import re -import csv -import os.path as osp -import time -import html -import shutil -import pandas as pd -import numpy as np -from collections import deque -import torch -import codecs -import imgviz -import argparse -from pathlib import Path -import functools -from qtpy import QtCore -from qtpy.QtCore import Qt -from qtpy import QtWidgets -from qtpy import QtGui -from labelme import PY2 -from labelme import QT5 -from PIL import ImageQt -import requests -import subprocess -from annolid.gui.shape import Shape -from labelme.app import MainWindow -from labelme.utils import newAction -from labelme.widgets import BrightnessContrastDialog -from labelme.widgets import LabelListWidgetItem -from labelme import utils -from annolid.utils.logger import logger -from annolid.utils.files import count_json_files -from labelme.widgets import ToolBar -from annolid.gui.label_file import LabelFileError -from annolid.gui.label_file import LabelFile -from annolid.configs 
import get_config -from annolid.gui.widgets.canvas import Canvas -from annolid.gui.widgets.text_prompt import AiRectangleWidget -from annolid.annotation import labelme2coco -from annolid.data import videos -from annolid.gui.widgets import ExtractFrameDialog -from annolid.gui.widgets import ConvertCOODialog -from annolid.gui.widgets import TrainModelDialog -from annolid.gui.widgets import Glitter2Dialog -from annolid.gui.widgets import QualityControlDialog -from annolid.gui.widgets import TrackDialog -from annolid.gui.widgets import SystemInfoDialog -from annolid.postprocessing.glitter import tracks2nix -from annolid.postprocessing.quality_control import TracksResults -from annolid.gui.widgets import ProgressingWindow -import webbrowser -import atexit -import qimage2ndarray -from annolid.gui.widgets.video_slider import VideoSlider, VideoSliderMark -from annolid.gui.widgets.step_size_widget import StepSizeWidget -from annolid.gui.widgets.downsample_videos_dialog import VideoRescaleWidget -from annolid.gui.widgets.convert_sleap_dialog import ConvertSleapDialog -from annolid.gui.widgets.extract_keypoints_dialog import ExtractShapeKeyPointsDialog -from annolid.gui.widgets.convert_labelme2csv_dialog import LabelmeJsonToCsvDialog -from annolid.postprocessing.quality_control import pred_dict_to_labelme -from annolid.annotation.timestamps import convert_frame_number_to_time -from annolid.segmentation.SAM.edge_sam_bg import VideoProcessor -from annolid.annotation import labelme2csv -from annolid.gui.widgets.advanced_parameters_dialog import AdvancedParametersDialog -from annolid.gui.widgets.place_preference_dialog import TrackingAnalyzerDialog -from annolid.data.videos import get_video_files -from annolid.gui.widgets.caption import CaptionWidget -from labelme.ai import MODELS __appname__ = 'Annolid' __version__ = "1.2.1" @@ -1980,6 +1980,8 @@ def set_frame_number(self, frame_number): f"{str(self.video_results_folder.name)}_{self.frame_number:09}.png" 
self.current_frame_time_stamp = self.video_loader.get_time_stamp() self.frame_loader.request(frame_number) + if self.caption_widget is not None: + self.caption_widget.set_image_path(self.filename) def load_tracking_results(self, cur_video_folder, video_filename): """Load tracking results from CSV files in the given folder that match the video filename.""" @@ -2261,7 +2263,10 @@ def image_to_canvas(self, qimage, filename, frame_number): _event_key = (frame_number, 'event_end') _, _state = _event_key if _event_key in self.timestamp_dict: - timestamp, behaivor, subject, trial_time = self.timestamp_dict[_event_key] + try: + timestamp, behaivor, subject, trial_time = self.timestamp_dict[_event_key] + except ValueError: + behaivor = "Others" if _state != 'event_end': flags[behaivor] = True else: @@ -2401,6 +2406,8 @@ def loadShapes(self, shapes, replace=True): caption) # Update caption widget def loadPredictShapes(self, frame_number, filename): + if self.caption_widget is not None: + self.caption_widget.set_image_path(filename) label_json_file = str(filename).replace(".png", ".json") # try to load json files generated by SAM2 like 000000000.json
diff --git a/annolid/gui/widgets/caption.py b/annolid/gui/widgets/caption.py
index 791fca1..ed2a061 100644
--- a/annolid/gui/widgets/caption.py
+++ b/annolid/gui/widgets/caption.py
@@ -1,6 +1,6 @@
 from qtpy import QtWidgets, QtGui, QtCore
-from qtpy.QtWidgets import QVBoxLayout, QTextEdit, QPushButton, QLabel
-from qtpy.QtCore import Signal, Qt
+from qtpy.QtWidgets import QVBoxLayout, QTextEdit, QPushButton, QLabel, QHBoxLayout
+from qtpy.QtCore import Signal, Qt, QRunnable, QThreadPool
 import threading
 
 
@@ -12,48 +12,114 @@ class CaptionWidget(QtWidgets.QWidget):
 
     def __init__(self, parent=None):
         super().__init__(parent)
+        self.init_ui()
+        self.previous_text = ""
+        self.image_path = ""
+        self.is_recording = False
+        self.thread_pool = QThreadPool()  # Thread pool for running background tasks
+
+    def init_ui(self):
+        """Initializes the UI components."""
         self.layout = QVBoxLayout(self)
 
         # Create a QTextEdit for editing captions
         self.text_edit = QTextEdit()
         self.layout.addWidget(self.text_edit)
 
-        # Create a circular button with a microphone icon for recording
-        self.record_button = QPushButton()
-        self.record_button.setFixedSize(50, 50)  # Make it a circle
-        self.record_button.setIcon(QtGui.QIcon.fromTheme("microphone"))
-        self.record_button.setIconSize(QtCore.QSize(30, 30))
-        self.record_button.setStyleSheet("""
-            QPushButton {
-                border: none;
-                background-color: #ff4d4d;
-                border-radius: 25px;
-            }
-            QPushButton:hover {
-                background-color: #e60000;
-            }
-        """)
-        self.layout.addWidget(self.record_button, alignment=Qt.AlignCenter)
+        # Create a horizontal layout for the buttons and labels
+        button_layout = QHBoxLayout()
 
-        # Add a label below the record button to display the recording status
+        # Create the record button with its label below
+        self.record_button = self.create_button(
+            icon_name="microphone",
+            color="#ff4d4d",
+            hover_color="#e60000"
+        )
         self.record_label = QLabel("Tap to record")
         self.record_label.setAlignment(Qt.AlignCenter)
-        self.layout.addWidget(self.record_label)
+        record_button_layout = QVBoxLayout()
+        record_button_layout.addWidget(
+            self.record_button, alignment=Qt.AlignCenter)
+        record_button_layout.addWidget(
+            self.record_label, alignment=Qt.AlignCenter)
 
-        # Store previous text for comparison
-        self.previous_text = ""
+        # Create the describe button with its label below
+        self.describe_button = self.create_button(
+            icon_name="view-preview",  # Adjust this icon as needed
+            color="#4d94ff",
+            hover_color="#0040ff"
+        )
+        self.describe_label = QLabel("Describe the image")
+        self.describe_label.setAlignment(Qt.AlignCenter)
+        describe_button_layout = QVBoxLayout()
+        describe_button_layout.addWidget(
+            self.describe_button, alignment=Qt.AlignCenter)
+        describe_button_layout.addWidget(
+            self.describe_label, alignment=Qt.AlignCenter)
+
+        # Add the save caption button
+        self.save_button = self.create_button(
+            icon_name="document-save",  # Adjust icon as needed
+            color="#66b3ff",
+            hover_color="#3399ff"
+        )
+
+        # Add the clear caption button
+        self.clear_button = self.create_button(
+            icon_name="edit-clear",  # Adjust icon as needed
+            color="#ffcc99",
+            hover_color="#ff9900"
+        )
+        self.clear_label = QLabel("Clear caption")
+        self.clear_label.setAlignment(Qt.AlignCenter)
+        clear_button_layout = QVBoxLayout()
+        clear_button_layout.addWidget(
+            self.clear_button, alignment=Qt.AlignCenter)
+        clear_button_layout.addWidget(
+            self.clear_label, alignment=Qt.AlignCenter)
+
+        # Connect the buttons to their respective methods
+        self.clear_button.clicked.connect(self.clear_caption)
+
+        # Add both button layouts to the horizontal layout
+        button_layout.addLayout(record_button_layout)
+        button_layout.addLayout(describe_button_layout)
+        # Add the new button layouts to the main button layout
+        button_layout.addLayout(clear_button_layout)
+
+        # Add the button layout to the main layout
+        self.layout.addLayout(button_layout)
+
+        # Connect describe button signal
+        self.describe_button.clicked.connect(self.on_describe_clicked)
 
         # Connect signals and slots
         self.text_edit.textChanged.connect(self.emit_caption_changed)
         self.text_edit.textChanged.connect(self.monitor_text_change)
         self.record_button.clicked.connect(self.toggle_recording)
 
-        # Initialize recording state
-        self.is_recording = False
         self.setLayout(self.layout)
 
+    def create_button(self, icon_name, color, hover_color):
+        """Creates a 20x20 themed-icon QPushButton styled round in the given colors."""
+        button = QPushButton()
+        button.setFixedSize(20, 20)  # border-radius below must stay <= half of this
+        button.setIcon(QtGui.QIcon.fromTheme(icon_name))
+        button.setIconSize(QtCore.QSize(10, 10))
+        button.setStyleSheet(f"""
+            QPushButton {{
+                border: none;
+                background-color: {color};
+                border-radius: 10px;
+            }}
+            QPushButton:hover {{
+                background-color: {hover_color};
+            }}
+        """)
+        return button
+
     def monitor_text_change(self):
-        """Monitor for text changes and emit signals for insertions and deletions."""
+        """Monitors text changes and emits appropriate signals."""
         current_text = self.text_edit.toPlainText()
         if self.previous_text == "":
             self.previous_text = current_text
@@ -74,6 +140,43 @@ def set_caption(self, caption_text):
         self.text_edit.setPlainText(caption_text)
         self.previous_text = caption_text
 
+    def clear_caption(self):
+        """Clears the caption."""
+        self.set_caption("")
+        self.clear_label.setText("Clear caption")
+        self.clear_label.setStyleSheet("color: red;")
+
+    def set_image_path(self, image_path):
+        """Sets the image path."""
+        self.image_path = image_path
+
+    def get_image_path(self):
+        """Returns the image path."""
+        return self.image_path
+
+    def on_describe_clicked(self):
+        """Handles the Describe button click and starts the background task."""
+        image_path = self.get_image_path()
+        if image_path:
+            self.describe_label.setText("Describing image...")
+            self.describe_label.setStyleSheet("color: blue;")
+            task = DescribeImageTask(image_path, self)
+            self.thread_pool.start(task)
+        else:
+            self.text_edit.setPlainText("No image selected for description.")
+
+    @QtCore.Slot(str, bool)
+    def update_description_status(self, message, is_error):
+        """Updates the description status in the UI."""
+        if is_error:
+            self.text_edit.setPlainText(message)
+            self.describe_label.setText("Description failed.")
+            self.describe_label.setStyleSheet("color: red;")
+        else:
+            self.text_edit.setPlainText(message)
+            self.describe_label.setText("Describe the image")
+            self.describe_label.setStyleSheet("color: green;")
+
     def emit_caption_changed(self):
         """Emits the captionChanged signal with the current caption."""
         self.captionChanged.emit(self.text_edit.toPlainText())
@@ -83,54 +186,43 @@ def get_caption(self):
         return self.text_edit.toPlainText()
 
     def toggle_recording(self):
-        """Toggles the recording state and starts or stops recording."""
+        """Toggles the recording state and updates the UI."""
         if not self.is_recording:
-            # Start recording
-            self.is_recording = True
-            self.record_button.setStyleSheet("""
-                QPushButton {
-                    border: none;
-                    background-color: red;
-                    border-radius: 25px;
-                }
-            """)
-            self.record_label.setText("Recording...")
-            self.record_label.setStyleSheet("color: red; font-size: 16px;")
-            threading.Thread(target=self.record_voice, daemon=True).start()
+            self.start_recording()
         else:
-            # Stop recording
-            self.is_recording = False
-            self.record_button.setStyleSheet("""
-                QPushButton {
-                    border: none;
-                    background-color: #ff4d4d;
-                    border-radius: 25px;
-                }
-            """)
-            self.record_label.setText("Tap to record")
-            self.record_label.setStyleSheet("color: black;")
+            self.stop_recording()
 
-    def toggle_recording(self):
-        """Toggles the recording state and starts or stops recording."""
-        if not self.is_recording:
-            # Start recording
-            self.is_recording = True
-            self.record_button.setStyleSheet("""
-                QPushButton {
-                    border: none;
-                    background-color: red;
-                    border-radius: 25px;
-                }
-            """)
-            self.record_label.setText("Recording...")
-            self.record_label.setStyleSheet("color: red; font-size: 16px;")
-            threading.Thread(target=self.record_voice, daemon=True).start()
-        else:
-            # Stop recording and show "Converting speech to text..."
-            self.is_recording = False
-            self.record_label.setText("Converting speech to text...")
-            self.record_label.setStyleSheet("color: blue; font-size: 16px;")
-            # No need to change button appearance here, it will reset after processing
+    def start_recording(self):
+        """Starts the recording process and updates UI accordingly."""
+        self.is_recording = True
+        self.record_button.setStyleSheet("""
+            QPushButton {
+                border: none;
+                background-color: red;
+                border-radius: 10px;
+            }
+        """)
+        self.record_label.setText("Recording...")
+        self.record_label.setStyleSheet("color: red; font-size: 16px;")
+        threading.Thread(target=self.record_voice, daemon=True).start()
+
+    def stop_recording(self):
+        """Stops recording and displays conversion status."""
+        self.is_recording = False
+        self.record_label.setText("Converting speech to text...")
+        self.record_label.setStyleSheet("color: blue; font-size: 16px;")
+
+    def stop_recording_ui_reset(self):
+        """Resets the UI after recording ends."""
+        self.record_button.setStyleSheet("""
+            QPushButton {
+                border: none;
+                background-color: #ff4d4d;
+                border-radius: 10px;
+            }
+        """)
+        self.record_label.setText("Tap to record")
+        self.record_label.setStyleSheet("color: black;")
 
     def record_voice(self):
         """Records voice input and converts it to text continuously until stopped."""
@@ -194,3 +286,60 @@ def record_voice(self):
         """)
         self.record_label.setText("Tap to record")
         self.record_label.setStyleSheet("color: black;")
+
+
+class DescribeImageTask(QRunnable):
+    """Runnable that asks the `llama3.2-vision` model, via ollama, to describe an image.
+
+    The outcome is handed to ``widget.update_description_status`` through a
+    queued ``invokeMethod`` call instead of calling the widget directly from
+    the worker thread.
+    """
+
+    def __init__(self, image_path, widget):
+        """Stores the image to describe and the widget to notify.
+
+        Args:
+            image_path: Image reference forwarded in the chat request.
+            widget: Object exposing the ``update_description_status(str, bool)``
+                slot.
+        """
+        super().__init__()
+        self.image_path = image_path
+        self.widget = widget
+
+    def _notify(self, message, is_error):
+        """Queues ``update_description_status(message, is_error)`` on the widget."""
+        QtCore.QMetaObject.invokeMethod(
+            self.widget, "update_description_status", QtCore.Qt.QueuedConnection,
+            QtCore.Q_ARG(str, message),
+            QtCore.Q_ARG(bool, is_error)
+        )
+
+    def run(self):
+        """Requests the description and reports either the text or the error."""
+        try:
+            # Imported inside the try block so a missing ollama package is
+            # reported through the same error path as a failed request.
+            import ollama
+            response = ollama.chat(
+                model='llama3.2-vision',
+                messages=[{
+                    'role': 'user',
+                    'content': 'Describe this image in detail.',
+                    'images': [self.image_path]
+                }]
+            )
+
+            # Access response content safely
+            if "message" in response and "content" in response["message"]:
+                description = response["message"]["content"]
+            else:
+                raise ValueError(
+                    "Unexpected response format: 'message' or 'content' key missing.")
+        except Exception as e:
+            # Worker entry point: no caller can handle the error, so every
+            # failure is converted into a status message for the widget.
+            self._notify(f"Error describing image: {e}", True)
+        else:
+            self._notify(description, False)