diff --git a/r2ai/ui/app.py b/r2ai/ui/app.py index c37912f..f7bd267 100644 --- a/r2ai/ui/app.py +++ b/r2ai/ui/app.py @@ -176,7 +176,6 @@ def on_button_pressed(self, event: Button.Pressed) -> None: if event.button.id == "send-button": self.send_message() - @work async def on_input_submitted(self, event: Input.Submitted) -> None: if event.input.id == "chat-input": await self.send_message() @@ -198,7 +197,6 @@ def on_message(self, type: str, message: any) -> None: elif type == 'tool_response': self.add_message(message["id"], "Tool Response", message['content']) - async def send_message(self) -> None: input_widget = self.query_one("#chat-input", Input) message = input_widget.value.strip() @@ -207,9 +205,13 @@ async def send_message(self) -> None: input_widget.value = "" try: await self.validate_model() - await chat(self.ai, message, self.on_message) + self.chat(message) except Exception as e: self.notify(str(e), severity="error") + + @work(thread=True) + async def chat(self, message: str) -> None: + await chat(self.ai, message, lambda type, message: self.call_from_thread(self.on_message, type, message)) async def validate_model(self) -> None: model = self.ai.model diff --git a/r2ai/ui/chat.py b/r2ai/ui/chat.py index c81e92b..cded705 100644 --- a/r2ai/ui/chat.py +++ b/r2ai/ui/chat.py @@ -22,16 +22,4 @@ async def chat(ai, message, cb): chat_auto = ChatAuto(model, interpreter=ai, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, cb=cb) - original_handler = signal.getsignal(signal.SIGINT) - - try: - signal.signal(signal.SIGINT, signal_handler) - return await chat_auto.achat(stream=True) - except KeyboardInterrupt: - tasks = asyncio.all_tasks() - for task in tasks: - task.cancel() - await asyncio.gather(*tasks, return_exceptions=True) - return None - finally: - signal.signal(signal.SIGINT, original_handler) \ No newline at end of file + return await chat_auto.achat(stream=True)