diff --git a/docker_overlay/etc/neon/neon.yaml b/docker_overlay/etc/neon/neon.yaml index dea30dc..c65b600 100644 --- a/docker_overlay/etc/neon/neon.yaml +++ b/docker_overlay/etc/neon/neon.yaml @@ -9,6 +9,9 @@ iris: webui_title: Neon AI webui_description: Chat with Neon webui_input_placeholder: Ask me something + webui_chatbot_label: Chat History + webui_mic_label: Speak to Neon + webui_text_label: Text with Neon server_address: "0.0.0.0" server_port: 7860 default_lang: en-us diff --git a/neon_iris/web_client.py b/neon_iris/web_client.py index 480a832..fd9f991 100644 --- a/neon_iris/web_client.py +++ b/neon_iris/web_client.py @@ -27,7 +27,7 @@ from os import makedirs from os.path import isfile, join, isdir from time import time -from typing import List, Dict +from typing import List, Dict, Tuple from uuid import uuid4 import gradio @@ -42,8 +42,6 @@ from ovos_utils.xdg_utils import xdg_data_home from neon_iris.client import NeonAIClient -import librosa -import soundfile as sf class GradIOClient(NeonAIClient): @@ -53,6 +51,7 @@ def __init__(self, lang: str = None): NeonAIClient.__init__(self, config.get("MQ")) self._await_response = Event() self._response = None + self._transcribed = None self._current_tts = dict() self._profiles: Dict[str, dict] = dict() self._audio_path = join(xdg_data_home(), "iris", "stt") @@ -118,21 +117,24 @@ def update_profile(self, stt_lang: str, tts_lang: str, tts_lang_2: str, LOG.info(f"Updated profile for: {session_id}") return session_id - def on_user_input(self, utterance: str, *args, **kwargs) -> str: + def on_user_input(self, utterance: str, + chat_history: List[Tuple[str, str]], + audio_input: str, + client_session: str) -> (List[Tuple[str, str]], str, str, None): """ Callback to handle textual user input @param utterance: String utterance submitted by the user - @returns: String response from Neon (or "ERROR") + @returns: Input box contents, Updated chat history, Gradio session ID """ input_time = time() LOG.debug(f"Input received") if not self._await_response.wait(30): LOG.error("Previous response not completed after 30 seconds") - LOG.debug(f"args={args}|kwargs={kwargs}") in_queue = time() - input_time self._await_response.clear() self._response = None - gradio_id = args[2] + self._transcribed = None + gradio_id = client_session lang = self.get_lang(gradio_id) if utterance: LOG.info(f"Sending utterance: {utterance} with lang: {lang}") @@ -142,8 +144,8 @@ def on_user_input(self, utterance: str, *args, **kwargs) -> str: "timing": {"wait_in_queue": in_queue, "gradio_sent": time()}}) else: - LOG.info(f"Sending audio: {args[1]} with lang: {lang}") - self.send_audio(args[1], lang, username=gradio_id, + LOG.info(f"Sending audio: {audio_input} with lang: {lang}") + self.send_audio(audio_input, lang, username=gradio_id, user_profiles=[self._profiles[gradio_id]], context={"gradio": {"session": gradio_id}, "timing": {"wait_in_queue": in_queue, @@ -153,7 +155,12 @@ def on_user_input(self, utterance: str, *args, **kwargs) -> str: self._await_response.set() self._response = self._response or "ERROR" LOG.info(f"Got response={self._response}") - return self._response + if utterance: + chat_history.append((utterance, self._response)) + elif self._transcribed: + LOG.info(f"Got transcript: {self._transcribed}") + chat_history.append((self._transcribed, self._response)) + return chat_history, gradio_id, "", None def play_tts(self, session_id: str): LOG.info(f"Playing most recent TTS file {self._current_tts}") @@ -166,36 +173,48 @@ def run(self): self._await_response.set() title = self.config.get("webui_title", "Neon AI") description = self.config.get("webui_description", "Chat With Neon") - chatbot = self.config.get("webui_chatbot_label") or description + chatbot_label = self.config.get("webui_chatbot_label") or description speech = self.config.get("webui_mic_label") or description + text_label = self.config.get("webui_text_label") or description placeholder = self.config.get("webui_input_placeholder", "Ask me something") address = self.config.get("server_address") or "0.0.0.0" port = self.config.get("server_port") or 7860 - chatbot = gradio.Chatbot(label=chatbot) - textbox = gradio.Textbox(placeholder=placeholder) - with self.chat_ui as blocks: client_session = gradio.State(self._start_session()) client_session.attach_load_event(self._start_session, None) # Define primary UI - audio_input = gradio.Audio(source="microphone", - type="filepath", - label=speech) - gradio.ChatInterface(self.on_user_input, - chatbot=chatbot, - textbox=textbox, - additional_inputs=[audio_input, client_session], - title=title, - retry_btn=None, - undo_btn=None, ) - tts_audio = gradio.Audio(autoplay=True, visible=True, - label="Neon's Response") - tts_button = gradio.Button("Play TTS") - tts_button.click(self.play_tts, - inputs=[client_session], - outputs=[tts_audio, client_session]) + blocks.title = title + chatbot = gradio.Chatbot(label=chatbot_label) + with gradio.Row(): + textbox = gradio.Textbox(label=text_label, + placeholder=placeholder, + scale=8) + audio_input = gradio.Audio(source="microphone", + type="filepath", + label=speech, + scale=2) + submit = gradio.Button(value="Submit", + variant="primary") + submit.click(self.on_user_input, + inputs=[textbox, chatbot, audio_input, + client_session], + outputs=[chatbot, client_session, textbox, + audio_input]) + textbox.submit(self.on_user_input, + inputs=[textbox, chatbot, audio_input, + client_session], + outputs=[chatbot, client_session, textbox, + audio_input]) + with gradio.Row(): + tts_audio = gradio.Audio(autoplay=True, visible=True, + label="Neon's Response", + scale=10) + tts_button = gradio.Button("Play TTS") + tts_button.click(self.play_tts, + inputs=[client_session], + outputs=[tts_audio, client_session]) # Define settings UI with gradio.Row(): with gradio.Column(): @@ -208,7 +227,7 @@ def run(self): value=lang) tts_lang_2 = gradio.Radio(label="Second Response Language", choices=[None] + - self.supported_languages, + self.supported_languages, value=None) with gradio.Column(): time_format = gradio.Radio(label="Time Format", @@ -284,6 +303,8 @@ def handle_api_response(self, message: Message): @param message: Response message to something emitted by this client """ LOG.debug(f"Got {message.msg_type}: {message.data}") + if message.msg_type == "neon.audio_input.response": + self._transcribed = message.data.get("transcripts", [""])[0] def _handle_profile_update(self, message: Message): updated_profile = message.data["profile"]