Skip to content

Commit

Permalink
Refactor GradIO UI to use Chatbot class directly
Browse files Browse the repository at this point in the history
Relocate audio input box next to `Submit` button
Clear audio input upon response to input
Move `Play TTS` button to visually match line above
Updates default web UI labels in Docker config
Closes #30
Closes #29
  • Loading branch information
NeonDaniel committed Nov 21, 2023
1 parent 9d62d03 commit 385dab9
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 31 deletions.
3 changes: 3 additions & 0 deletions docker_overlay/etc/neon/neon.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ iris:
webui_title: Neon AI
webui_description: Chat with Neon
webui_input_placeholder: Ask me something
webui_chatbot_label: Chat History
webui_mic_label: Speak to Neon
webui_text_label: Text with Neon
server_address: "0.0.0.0"
server_port: 7860
default_lang: en-us
Expand Down
83 changes: 52 additions & 31 deletions neon_iris/web_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from os import makedirs
from os.path import isfile, join, isdir
from time import time
from typing import List, Dict
from typing import List, Dict, Tuple
from uuid import uuid4

import gradio
Expand All @@ -42,8 +42,6 @@
from ovos_utils.xdg_utils import xdg_data_home

from neon_iris.client import NeonAIClient
import librosa
import soundfile as sf


class GradIOClient(NeonAIClient):
Expand All @@ -53,6 +51,7 @@ def __init__(self, lang: str = None):
NeonAIClient.__init__(self, config.get("MQ"))
self._await_response = Event()
self._response = None
self._transcribed = None
self._current_tts = dict()
self._profiles: Dict[str, dict] = dict()
self._audio_path = join(xdg_data_home(), "iris", "stt")
Expand Down Expand Up @@ -118,21 +117,24 @@ def update_profile(self, stt_lang: str, tts_lang: str, tts_lang_2: str,
LOG.info(f"Updated profile for: {session_id}")
return session_id

def on_user_input(self, utterance: str, *args, **kwargs) -> str:
def on_user_input(self, utterance: str,
chat_history: List[Tuple[str, str]],
audio_input: str,
client_session: str) -> (List[Tuple[str, str]], str, str, None):
"""
Callback to handle textual user input
@param utterance: String utterance submitted by the user
@returns: String response from Neon (or "ERROR")
@returns: Input box contents, Updated chat history, Gradio session ID
"""
input_time = time()
LOG.debug(f"Input received")
if not self._await_response.wait(30):
LOG.error("Previous response not completed after 30 seconds")
LOG.debug(f"args={args}|kwargs={kwargs}")
in_queue = time() - input_time
self._await_response.clear()
self._response = None
gradio_id = args[2]
self._transcribed = None
gradio_id = client_session
lang = self.get_lang(gradio_id)
if utterance:
LOG.info(f"Sending utterance: {utterance} with lang: {lang}")
Expand All @@ -142,8 +144,8 @@ def on_user_input(self, utterance: str, *args, **kwargs) -> str:
"timing": {"wait_in_queue": in_queue,
"gradio_sent": time()}})
else:
LOG.info(f"Sending audio: {args[1]} with lang: {lang}")
self.send_audio(args[1], lang, username=gradio_id,
LOG.info(f"Sending audio: {audio_input} with lang: {lang}")
self.send_audio(audio_input, lang, username=gradio_id,
user_profiles=[self._profiles[gradio_id]],
context={"gradio": {"session": gradio_id},
"timing": {"wait_in_queue": in_queue,
Expand All @@ -153,7 +155,12 @@ def on_user_input(self, utterance: str, *args, **kwargs) -> str:
self._await_response.set()
self._response = self._response or "ERROR"
LOG.info(f"Got response={self._response}")
return self._response
if utterance:
chat_history.append((utterance, self._response))
elif self._transcribed:
LOG.info(f"Got transcript: {self._transcribed}")
chat_history.append((self._transcribed, self._response))
return chat_history, gradio_id, "", None

def play_tts(self, session_id: str):
LOG.info(f"Playing most recent TTS file {self._current_tts}")
Expand All @@ -166,36 +173,48 @@ def run(self):
self._await_response.set()
title = self.config.get("webui_title", "Neon AI")
description = self.config.get("webui_description", "Chat With Neon")
chatbot = self.config.get("webui_chatbot_label") or description
chatbot_label = self.config.get("webui_chatbot_label") or description
speech = self.config.get("webui_mic_label") or description
text_label = self.config.get("webui_text_label") or description
placeholder = self.config.get("webui_input_placeholder",
"Ask me something")
address = self.config.get("server_address") or "0.0.0.0"
port = self.config.get("server_port") or 7860

chatbot = gradio.Chatbot(label=chatbot)
textbox = gradio.Textbox(placeholder=placeholder)

with self.chat_ui as blocks:
client_session = gradio.State(self._start_session())
client_session.attach_load_event(self._start_session, None)
# Define primary UI
audio_input = gradio.Audio(source="microphone",
type="filepath",
label=speech)
gradio.ChatInterface(self.on_user_input,
chatbot=chatbot,
textbox=textbox,
additional_inputs=[audio_input, client_session],
title=title,
retry_btn=None,
undo_btn=None, )
tts_audio = gradio.Audio(autoplay=True, visible=True,
label="Neon's Response")
tts_button = gradio.Button("Play TTS")
tts_button.click(self.play_tts,
inputs=[client_session],
outputs=[tts_audio, client_session])
blocks.title = title
chatbot = gradio.Chatbot(label=chatbot_label)
with gradio.Row():
textbox = gradio.Textbox(label=text_label,
placeholder=placeholder,
scale=8)
audio_input = gradio.Audio(source="microphone",
type="filepath",
label=speech,
scale=2)
submit = gradio.Button(value="Submit",
variant="primary")
submit.click(self.on_user_input,
inputs=[textbox, chatbot, audio_input,
client_session],
outputs=[chatbot, client_session, textbox,
audio_input])
textbox.submit(self.on_user_input,
inputs=[textbox, chatbot, audio_input,
client_session],
outputs=[chatbot, client_session, textbox,
audio_input])
with gradio.Row():
tts_audio = gradio.Audio(autoplay=True, visible=True,
label="Neon's Response",
scale=10)
tts_button = gradio.Button("Play TTS")
tts_button.click(self.play_tts,
inputs=[client_session],
outputs=[tts_audio, client_session])
# Define settings UI
with gradio.Row():
with gradio.Column():
Expand All @@ -208,7 +227,7 @@ def run(self):
value=lang)
tts_lang_2 = gradio.Radio(label="Second Response Language",
choices=[None] +
self.supported_languages,
self.supported_languages,
value=None)
with gradio.Column():
time_format = gradio.Radio(label="Time Format",
Expand Down Expand Up @@ -284,6 +303,8 @@ def handle_api_response(self, message: Message):
@param message: Response message to something emitted by this client
"""
LOG.debug(f"Got {message.msg_type}: {message.data}")
if message.msg_type == "neon.audio_input.response":
self._transcribed = message.data.get("transcripts", [""])[0]

def _handle_profile_update(self, message: Message):
updated_profile = message.data["profile"]
Expand Down

0 comments on commit 385dab9

Please sign in to comment.