From 11bff4e188666e3ca0b06d6475fcb714d8b39ce7 Mon Sep 17 00:00:00 2001 From: healthonrails Date: Sat, 9 Nov 2024 14:32:04 -0500 Subject: [PATCH] Add support for chatting with Ollama models for video frame or image analysis --- annolid/gui/app.py | 22 +++++-- annolid/gui/widgets/caption.py | 113 +++++++++++++++++++++++++++++---- 2 files changed, 119 insertions(+), 16 deletions(-) diff --git a/annolid/gui/app.py b/annolid/gui/app.py index f59c988..8dc858d 100644 --- a/annolid/gui/app.py +++ b/annolid/gui/app.py @@ -1311,7 +1311,6 @@ def format_shape(s): flags=flags, caption=self.canvas.getCaption(), ) - logger.info(f"Saved image and label json file: {filename}") if has_zone_shapes: self.zone_path = filename @@ -1356,6 +1355,10 @@ def _saveFile(self, filename): self.addRecentFile(filename) label_file = self._getLabelFile(filename) self._addItem(image_filename, label_file) + + if self.caption_widget is not None: + self.caption_widget.set_image_path(image_filename) + self.setClean() def getLabelFile(self): @@ -1410,6 +1413,9 @@ def setDirty(self): def getTitle(self, clean=True): title = __appname__ + if self.caption_widget is None: + self.openCaption() + self.caption_widget.set_image_path(self.filename) _filename = os.path.basename(self.filename) if self.video_loader: if self.frame_number: @@ -2398,12 +2404,16 @@ def loadShapes(self, shapes, replace=True): self.labelList.clearSelection() self._noSelectionSlot = False self.canvas.loadShapes(shapes, replace=replace) - caption = self.labelFile.get_caption() if self.labelFile else None + try: + caption = self.labelFile.get_caption() if self.labelFile else None + except AttributeError: + caption = None if caption is not None: if self.caption_widget is None: self.openCaption() - self.caption_widget.set_caption( - caption) # Update caption widget + self.caption_widget.set_caption( + caption) # Update caption widget + self.caption_widget.set_image_path(self.filename) def loadPredictShapes(self, frame_number, filename): if self.caption_widget is not None: @@ -2500,6 +2510,8 @@ def openNextImg(self, _value=False, load=True): if self.filename and load: self.loadFile(self.filename) + if self.caption_widget is not None: + self.caption_widget.set_image_path(self.filename) self._config["keep_prev"] = keep_prev @@ -2533,6 +2545,8 @@ def openPrevImg(self, _value=False): filename = self.imageList[currIndex - 1] if filename: self.loadFile(filename) + if self.caption_widget is not None: + self.caption_widget.set_image_path(self.filename) self._config["keep_prev"] = keep_prev diff --git a/annolid/gui/widgets/caption.py b/annolid/gui/widgets/caption.py index a7f479b..1449a13 100644 --- a/annolid/gui/widgets/caption.py +++ b/annolid/gui/widgets/caption.py @@ -1,5 +1,5 @@ from qtpy import QtWidgets, QtGui, QtCore -from qtpy.QtWidgets import QVBoxLayout, QTextEdit, QPushButton, QLabel, QHBoxLayout +from qtpy.QtWidgets import QVBoxLayout, QTextEdit, QPushButton, QLabel, QHBoxLayout, QLineEdit from qtpy.QtCore import Signal, Qt, QRunnable, QThreadPool, QMetaObject import threading import os @@ -60,13 +60,6 @@ def init_ui(self): describe_button_layout.addWidget( self.describe_label, alignment=Qt.AlignCenter) - # Add the save caption button - self.save_button = self.create_button( - icon_name="document-save", # Adjust icon as needed - color="#66b3ff", - hover_color="#3399ff" - ) - # Add the clear caption button self.clear_button = self.create_button( icon_name="edit-clear", # Adjust icon as needed @@ -126,9 +119,6 @@ def init_ui(self): # (Add read button layout to the main button layout) button_layout.addLayout(clear_button_layout) - # Add the button layout to the main layout - self.layout.addLayout(button_layout) - # Connect describe button signal self.describe_button.clicked.connect(self.on_describe_clicked) @@ -139,8 +129,62 @@ def init_ui(self): # Connect the signal to the slot self.readCaptionFinished.connect(self.on_read_caption_finished) + # Horizontal layout for the prompt text edit and chat button + self.input_layout = QtWidgets.QHBoxLayout() + + # Prompt text editor for user input + self.prompt_text_edit = QtWidgets.QLineEdit(self) + self.prompt_text_edit.setPlaceholderText( + "Type your chat prompt here...") + self.input_layout.addWidget(self.prompt_text_edit) + + # Chat button + self.chat_button = QtWidgets.QPushButton("Chat", self) + self.chat_button.clicked.connect(self.chat_with_ollama) + self.input_layout.addWidget(self.chat_button) + + # Add the input layout to the main layout + self.layout.addLayout(self.input_layout) + + # Add the button layout to the main layout + self.layout.addLayout(button_layout) + + # Integrate existing layouts self.setLayout(self.layout) + def chat_with_ollama(self): + """Initiates a chat with the Ollama model and displays chat history.""" + user_input = self.prompt_text_edit.text() + if not user_input: + print("No input provided for chat.") + return + + # Append user's input to the chat history + self.append_to_chat_history(f"User: {user_input}") + + # Update UI to indicate that a chat is in progress + self.chat_button.setEnabled(False) + + # Start the chat task + task = ChatWithOllamaTask(user_input, self.image_path, self) + self.thread_pool.start(task) + self.prompt_text_edit.clear() + + def append_to_chat_history(self, message): + """Appends a message to the chat history display.""" + self.text_edit.append(message + '\n') + + @QtCore.Slot(str, bool) + def update_chat_response(self, message, is_error): + """Handles the chat response.""" + if is_error: + self.text_edit.append("\nError: " + message) + else: + self.text_edit.append("\nOllama: " + message) + + # Reset UI + self.chat_button.setEnabled(True) + def create_button(self, icon_name, color, hover_color): """Creates and returns a styled button.""" button = QPushButton() @@ -475,7 +519,8 @@ def run(self): "Unexpected response format: 'message' or 'content' key missing.") except Exception as e: - error_message = f"Error describing image: {e}" + error_message = f"An error occurred while describing the image: {e}.\n" + error_message += "Please save the video frame to disk by clicking the 'Save' button or pressing Ctrl/Cmd + S." QtCore.QMetaObject.invokeMethod( self.widget, "update_description_status", QtCore.Qt.QueuedConnection, QtCore.Q_ARG(str, error_message), @@ -493,3 +538,47 @@ def __init__(self, widget): def run(self): """Runs the read_caption method in the background.""" self.widget.read_caption() + + +class ChatWithOllamaTask(QRunnable): + """A task to chat with the Ollama model in the background.""" + + def __init__(self, prompt, image_path=None, widget=None): + super().__init__() + self.prompt = prompt + self.image_path = image_path + self.widget = widget + + def run(self): + """Sends a chat message to Ollama and processes the response.""" + try: + import ollama + + messages = [{'role': 'user', 'content': self.prompt}] + if self.image_path: + # Attach the image if provided + messages[0]['images'] = [self.image_path] + + response = ollama.chat( + model='llama3.2-vision', + messages=messages, + ) + + # Check and handle the response + if "message" in response and "content" in response["message"]: + response_content = response["message"]["content"] + QtCore.QMetaObject.invokeMethod( + self.widget, "update_chat_response", QtCore.Qt.QueuedConnection, + QtCore.Q_ARG(str, response_content), + QtCore.Q_ARG(bool, False) + ) + else: + raise ValueError("Unexpected response format from Ollama.") + + except Exception as e: + error_message = f"Error in chat interaction: {e}" + QtCore.QMetaObject.invokeMethod( + self.widget, "update_chat_response", QtCore.Qt.QueuedConnection, + QtCore.Q_ARG(str, error_message), + QtCore.Q_ARG(bool, True) + )