From 57fe969a6e614015f6a493b211df4154b5613b82 Mon Sep 17 00:00:00 2001 From: Fakhir Ali Date: Sun, 26 May 2024 14:33:55 +0500 Subject: [PATCH] Stopping in run_chat --- main.py | 2 +- openvoicechat/tts/base.py | 14 ++++++++------ openvoicechat/utils.py | 18 +++++++++++++++++- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/main.py b/main.py index 95f4404..ad476f5 100644 --- a/main.py +++ b/main.py @@ -20,4 +20,4 @@ api_key=api_key) mouth = Mouth(device=device) mouth.say_text('Good morning!') - run_chat(mouth, ear, chatbot, verbose=True) + run_chat(mouth, ear, chatbot, verbose=True, stopping_criteria=lambda x: '[END]' in x) diff --git a/openvoicechat/tts/base.py b/openvoicechat/tts/base.py index 4a67c49..997b9cb 100644 --- a/openvoicechat/tts/base.py +++ b/openvoicechat/tts/base.py @@ -109,6 +109,7 @@ def say_multiple_stream(self, text_queue: queue.Queue, ''' response = '' all_response = [] + interrupt_text_list = [] if audio_queue is None: audio_queue = queue.Queue() say_thread = threading.Thread(target=self.say, args=(audio_queue, listen_interruption_func)) @@ -116,7 +117,7 @@ def say_multiple_stream(self, text_queue: queue.Queue, while True: text = text_queue.get() if text is None: - sentence = remove_words_in_brackets_and_spaces(response).strip() + sentence = response else: response += text if bool(re.search(self.sentence_stop_pattern, response)): @@ -127,12 +128,13 @@ def say_multiple_stream(self, text_queue: queue.Queue, continue if sentence.strip() == '': break - sentence = remove_words_in_brackets_and_spaces(sentence).strip() - output = self.run_tts(sentence) - audio_queue.put((output, sentence)) + clean_sentence = remove_words_in_brackets_and_spaces(sentence).strip() + output = self.run_tts(clean_sentence) + audio_queue.put((output, clean_sentence)) all_response.append(sentence) + interrupt_text_list.append(clean_sentence) if self.interrupted: - all_response = self._handle_interruption(all_response, interrupt_queue) + all_response = self._handle_interruption(interrupt_text_list, interrupt_queue) self.interrupted = '' break if text is None: @@ -140,6 +142,6 @@ def say_multiple_stream(self, text_queue: queue.Queue, audio_queue.put((None, '')) say_thread.join() if self.interrupted: - all_response = self._handle_interruption(all_response, interrupt_queue) + all_response = self._handle_interruption(interrupt_text_list, interrupt_queue) text_queue.queue.clear() text_queue.put('. '.join(all_response)) diff --git a/openvoicechat/utils.py b/openvoicechat/utils.py index d7872cd..55f5c6f 100644 --- a/openvoicechat/utils.py +++ b/openvoicechat/utils.py @@ -5,7 +5,21 @@ -def run_chat(mouth, ear, chatbot, verbose=True): +def run_chat(mouth, ear, chatbot, verbose=True, + stopping_criteria=lambda x: False): + """ + Runs a chat session between a user and a bot. + + Parameters: mouth (object): An object responsible for the bot's speech output. ear (object): An object + responsible for listening to the user's input. chatbot (object): An object responsible for generating the bot's + responses. verbose (bool, optional): If True, prints the user's input and the bot's responses. Defaults to True. + stopping_criteria (function, optional): A function that determines when the chat should stop. It takes the bot's + response as input and returns a boolean. Defaults to a function that always returns False. + + The function works by continuously listening to the user's input and generating the bot's responses in separate + threads. If the user interrupts the bot's speech, the remaining part of the bot's response is saved and prepended + to the user's next input. The chat stops when the stopping_criteria function returns True for a bot's response. + """ pre_interruption_text = '' while True: user_input = pre_interruption_text + ' ' + ear.listen() @@ -28,6 +42,8 @@ def run_chat(mouth, ear, chatbot, verbose=True): pre_interruption_text = interrupt_queue.get() res = llm_output_queue.get() + if stopping_criteria(res): + break if verbose: print('BOT: ', res)