Skip to content

Commit

Permalink
Threaded input handling and multi-session support (#31)
Browse files Browse the repository at this point in the history
* Prevent sending an input until the previous response has been handled
This would ideally use a queue, but that would require a different UI, since the Gradio ChatBot expects each input to return a value synchronously.
Relates to #26

* Implement gradio State to track a session ID
Update handling so TTS responses are attached to a specific browser session

* Implement session-specific profile settings

* Add remaining user profile params to UI

---------

Co-authored-by: Daniel McKnight <[email protected]>
  • Loading branch information
NeonDaniel and NeonDaniel authored Nov 8, 2023
1 parent 1f13ed1 commit 37bf01f
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 40 deletions.
27 changes: 19 additions & 8 deletions neon_iris/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from typing import Optional
from uuid import uuid4
from ovos_bus_client.message import Message
from ovos_utils.json_helper import merge_dict
from pika.exceptions import StreamLostError
from neon_utils.configuration_utils import get_neon_user_config
from neon_utils.mq_utils import NeonMQHandler
Expand Down Expand Up @@ -228,27 +229,31 @@ def _clear_audio_cache():

def send_utterance(self, utterance: str, lang: str = "en-us",
                   username: Optional[str] = None,
                   user_profiles: Optional[list] = None,
                   context: Optional[dict] = None):
    """
    Optionally override this to queue text inputs or do any pre-parsing
    :param utterance: utterance to submit to skills module
    :param lang: language code associated with request
    :param username: username associated with request
    :param user_profiles: user profiles expecting a response
    :param context: Optional dict context to add to emitted message
    """
    # Single delegation point; all arguments (including the optional
    # context) are forwarded unchanged to the private sender.
    self._send_utterance(utterance, lang, username, user_profiles, context)

def send_audio(self, audio_file: str, lang: str = "en-us",
               username: Optional[str] = None,
               user_profiles: Optional[list] = None,
               context: Optional[dict] = None):
    """
    Optionally override this to queue audio inputs or do any pre-parsing
    :param audio_file: path to audio file to send to speech module
    :param lang: language code associated with request
    :param username: username associated with request
    :param user_profiles: user profiles expecting a response
    :param context: Optional dict context to add to emitted message
    """
    # Single delegation point; all arguments (including the optional
    # context) are forwarded unchanged to the private sender.
    self._send_audio(audio_file, lang, username, user_profiles, context)

def _build_message(self, msg_type: str, data: dict,
username: Optional[str] = None,
Expand All @@ -267,19 +272,24 @@ def _build_message(self, msg_type: str, data: dict,
})

def _send_utterance(self, utterance: str, lang: str,
                    username: Optional[str], user_profiles: Optional[list],
                    context: Optional[dict] = None):
    """
    Build a `recognizer_loop:utterance` message and send it serialized
    :param utterance: utterance to submit
    :param lang: language code associated with request
    :param username: username associated with request; a falsy value
        falls back to `self.default_username`
    :param user_profiles: user profiles expecting a response; a falsy value
        falls back to `[self.user_config]`
    :param context: Optional dict context to add to emitted message
    """
    context = context or dict()
    username = username or self.default_username
    user_profiles = user_profiles or [self.user_config]
    message = self._build_message("recognizer_loop:utterance",
                                  {"utterances": [utterance],
                                   "lang": lang}, username, user_profiles)
    # NOTE(review): `new_only=True` presumably means caller-supplied context
    # only adds keys and never replaces ones set by `_build_message`;
    # confirm against `merge_dict`.
    serialized = {"msg_type": message.msg_type,
                  "data": message.data,
                  "context": merge_dict(message.context, context,
                                        new_only=True)}
    self._send_serialized_message(serialized)

def _send_audio(self, audio_file: str, lang: str,
username: str, user_profiles: list):
username: str, user_profiles: list,
context: Optional[dict] = None):
context = context or dict()
audio_data = encode_file_to_base64_string(audio_file)
message = self._build_message("neon.audio_input",
{"lang": lang,
Expand All @@ -289,7 +299,8 @@ def _send_audio(self, audio_file: str, lang: str,
username, user_profiles)
serialized = {"msg_type": message.msg_type,
"data": message.data,
"context": message.context}
"context": merge_dict(message.context, context,
new_only=True)}
self._send_serialized_message(serialized)

def _send_serialized_message(self, serialized: dict):
Expand Down
128 changes: 96 additions & 32 deletions neon_iris/web_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,17 @@
from os import makedirs
from os.path import isfile, join, isdir
from time import time
from typing import List, Optional
from typing import List, Optional, Dict
from uuid import uuid4

import gradio

from threading import Event
from ovos_bus_client import Message
from ovos_config import Configuration
from ovos_utils import LOG
from ovos_utils.json_helper import merge_dict

from neon_utils.file_utils import decode_base64_string_to_file
from ovos_utils.xdg_utils import xdg_data_home

Expand All @@ -50,12 +53,15 @@ def __init__(self, lang: str = None):
NeonAIClient.__init__(self, config.get("MQ"))
self._await_response = Event()
self._response = None
self._current_tts = None
self._current_tts = dict()
self._profiles: Dict[str, dict] = dict()
self._audio_path = join(xdg_data_home(), "iris", "stt")
if not isdir(self._audio_path):
makedirs(self._audio_path)
self.default_lang = lang or self.config.get('default_lang')
self.chat_ui = gradio.Blocks()
LOG.name = "iris"
LOG.init(self.config.get("logs"))

@property
def lang(self):
Expand All @@ -69,24 +75,52 @@ def supported_languages(self) -> List[str]:
"""
return self.config.get('languages') or [self.default_lang]

def update_profile(self, stt_lang: str, tts_lang: str, tts_lang_2: str):
def _start_session(self):
    """
    Register a new browser session with an empty TTS slot and its own
    user profile
    :returns: hex string session ID
    """
    from copy import deepcopy

    sid = uuid4().hex
    self._current_tts[sid] = None
    # Copy the default configuration before tagging it with the session ID;
    # storing `self.user_config` by reference would make every session share
    # (and overwrite) the same profile dict.
    profile = deepcopy(self.user_config)
    profile['user']['username'] = sid
    self._profiles[sid] = profile
    return sid

def update_profile(self, stt_lang: str, tts_lang: str, tts_lang_2: str,
                   time: int, date: str, uom: str, city: str, state: str,
                   country: str, first: str, middle: str, last: str,
                   pref_name: str, email: str, session_id: str):
    """
    Callback to handle user settings changes from the web UI
    :param stt_lang: speech-to-text language
    :param tts_lang: primary text-to-speech language
    :param tts_lang_2: secondary text-to-speech language
    :param time: time format setting
    :param date: date format setting
    :param uom: unit of measure setting
    :param city: city used for the location lookup
    :param state: state used for the location lookup
    :param country: country used for the location lookup
    :param first: first name
    :param middle: middle name
    :param last: last name
    :param pref_name: preferred name
    :param email: email address
    :param session_id: session whose profile is updated
    :returns: the unchanged `session_id`
    """
    from copy import deepcopy

    location_dict = dict()
    if any((city, state, country)):
        from neon_utils.location_utils import get_coordinates, get_timezone
        try:
            location_dict = {"city": city, "state": state,
                             "country": country}
            lat, lon = get_coordinates(location_dict)
            location_dict["lat"] = lat
            location_dict["lng"] = lon
            location_dict["tz"], location_dict["utc"] = get_timezone(lat,
                                                                     lon)
            LOG.debug(f"Got location update: {location_dict}")
        except Exception as e:
            # Best-effort lookup: a failure is logged and the remaining
            # profile fields are still applied.
            LOG.exception(e)

    profile_update = {"speech": {"stt_language": stt_lang,
                                 "tts_language": tts_lang,
                                 "secondary_tts_language": tts_lang_2},
                      "units": {"time": time, "date": date, "measure": uom},
                      "location": location_dict,
                      "user": {"first_name": first, "middle_name": middle,
                               "last_name": last,
                               "preferred_name": pref_name, "email": email}}
    old_profile = self._profiles.get(session_id)
    if not old_profile:
        # No profile for this session yet: start from a copy of the default
        # configuration. `merge_dict` may update its first argument in place,
        # which would otherwise write this session's settings into the
        # shared default.
        old_profile = deepcopy(self.user_config)
    self._profiles[session_id] = merge_dict(old_profile, profile_update)
    LOG.info(f"Updated profile for: {session_id}")
    return session_id

def send_audio(self, audio_file: str, lang: str = "en-us",
username: Optional[str] = None,
user_profiles: Optional[list] = None):
user_profiles: Optional[list] = None,
context: Optional[dict] = None):
"""
@param audio_file: path to wav audio file to send to speech module
@param lang: language code associated with request
Expand All @@ -95,7 +129,7 @@ def send_audio(self, audio_file: str, lang: str = "en-us",
"""
# TODO: Audio conversion is really slow here. check ovos-stt-http-server
audio_file = self.convert_audio(audio_file)
self._send_audio(audio_file, lang, username, user_profiles)
self._send_audio(audio_file, lang, username, user_profiles, context)

def convert_audio(self, audio_file: str, target_sr=16000, target_channels=1,
dtype='int16') -> str:
Expand Down Expand Up @@ -128,29 +162,37 @@ def on_user_input(self, utterance: str, *args, **kwargs) -> str:
@param utterance: String utterance submitted by the user
@returns: String response from Neon (or "ERROR")
"""
# TODO: This should probably queue with a separate iterator thread
LOG.debug(f"Input received")
if not self._await_response.wait(30):
LOG.error("Previous response not completed after 30 seconds")
LOG.debug(f"args={args}|kwargs={kwargs}")
self._await_response.clear()
self._response = None
gradio_id = args[2]
if utterance:
LOG.info(f"Sending utterance: {utterance} with lang: {self.lang}")
self.send_utterance(utterance, self.lang)
self.send_utterance(utterance, self.lang, username=gradio_id,
user_profiles=[self._profiles[gradio_id]],
context={"gradio": {"session": gradio_id}})
else:
LOG.info(f"Sending audio: {args[1]} with lang: {self.lang}")
self.send_audio(args[1], self.lang)
self.send_audio(args[1], self.lang, username=gradio_id,
user_profiles=[self._profiles[gradio_id]],
context={"gradio": {"session": gradio_id}})
self._await_response.wait(30)
self._response = self._response or "ERROR"
LOG.info(f"Got response={self._response}")
return self._response

def play_tts(self, session_id: str):
    """
    Callback to get the most recent TTS file for a session
    :param session_id: session requesting playback
    :returns: tuple of (TTS file path or None if no entry, session_id)
    """
    tts_file = self._current_tts.get(session_id)
    # Log only this session's file; `self._current_tts` is a dict holding
    # entries for every session.
    LOG.info(f"Playing most recent TTS file {tts_file}")
    return tts_file, session_id

def run(self):
"""
Blocking method to start the web server
"""
self._await_response.set()
title = self.config.get("webui_title", "Neon AI")
description = self.config.get("webui_description", "Chat With Neon")
chatbot = self.config.get("webui_chatbot_label") or description
Expand All @@ -164,22 +206,25 @@ def run(self):
textbox = gradio.Textbox(placeholder=placeholder)

with self.chat_ui as blocks:
client_session = gradio.State(self._start_session())
client_session.attach_load_event(self._start_session, None)
# Define primary UI
audio_input = gradio.Audio(source="microphone",
type="filepath",
label=speech)
gradio.ChatInterface(self.on_user_input,
chatbot=chatbot,
textbox=textbox,
additional_inputs=[audio_input],
additional_inputs=[audio_input, client_session],
title=title,
retry_btn=None,
undo_btn=None, )
tts_audio = gradio.Audio(autoplay=True, visible=True,
label="Neon's Response")
tts_button = gradio.Button("Play TTS")
tts_button.click(self.play_tts,
outputs=[tts_audio])
inputs=[client_session],
outputs=[tts_audio, client_session])
# Define settings UI
with gradio.Row():
with gradio.Column():
Expand All @@ -193,18 +238,36 @@ def run(self):
choices=[None] +
self.supported_languages,
value=None)
submit = gradio.Button("Update User Settings")
with gradio.Column():
# TODO: Unit settings
pass
time_format = gradio.Radio(label="Time Format",
choices=[12, 24],
value=12)
date_format = gradio.Radio(label="Date Format",
choices=["MDY", "YMD", "DMY",
"YDM"],
value="MDY")
unit_of_measure = gradio.Radio(label="Units of Measure",
choices=["imperial",
"metric"],
value="imperial")
with gradio.Column():
# TODO: Location settings
pass
city = gradio.Textbox(label="City")
state = gradio.Textbox(label="State")
country = gradio.Textbox(label="Country")
with gradio.Column():
# TODO Name settings
pass
first_name = gradio.Textbox(label="First Name")
middle_name = gradio.Textbox(label="Middle Name")
last_name = gradio.Textbox(label="Last Name")
pref_name = gradio.Textbox(label="Preferred Name")
email_addr = gradio.Textbox(label="Email Address")
# TODO: DoB, pic, about, phone?
submit = gradio.Button("Update User Settings")
submit.click(self.update_profile,
inputs=[stt_lang, tts_lang, tts_lang_2])
inputs=[stt_lang, tts_lang, tts_lang_2, time_format,
date_format, unit_of_measure, city, state,
country, first_name, middle_name, last_name,
pref_name, email_addr, client_session],
outputs=[client_session])
blocks.launch(server_name=address, server_port=port)

def handle_klat_response(self, message: Message):
Expand All @@ -213,19 +276,20 @@ def handle_klat_response(self, message: Message):
audio in all requested languages.
@param message: Neon response message
"""
LOG.debug(f"Response_data={message.data}")
LOG.debug(f"gradio context={message.context['gradio']}")
resp_data = message.data["responses"]
files = []
sentences = []
session = message.context['gradio']['session']
for lang, response in resp_data.items():
sentences.append(response.get("sentence"))
if response.get("audio"):
for gender, data in response["audio"].items():
filepath = "/".join([self.audio_cache_dir] +
response[gender].split('/')[-4:])
# TODO: This only plays the most recent, so it doesn't support
# multiple languages
self._current_tts = filepath
# TODO: This only plays the most recent, so it doesn't
# support multiple languages or multi-utterance responses
self._current_tts[session] = filepath
files.append(filepath)
if not isfile(filepath):
decode_base64_string_to_file(data, filepath)
Expand Down

0 comments on commit 37bf01f

Please sign in to comment.