diff --git a/r2ai/auto.py b/r2ai/auto.py
index 2e859ab..dbd417a 100644
--- a/r2ai/auto.py
+++ b/r2ai/auto.py
@@ -39,7 +39,7 @@
 """
 
 class ChatAuto:
-    def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interpreter=None, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, timeout=None, cb=None ):
+    def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interpreter=None, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, timeout=None, stream=True, cb=None ):
         self.logger = LOGGER
         self.functions = {}
         self.tools = []
@@ -52,6 +52,7 @@ def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interp
         self.interpreter = interpreter
         self.system_message = None
         self.timeout = timeout
+        self.stream = stream
         if messages and messages[0]['role'] != 'system' and system:
             self.messages.insert(0, { "role": "system", "content": system })
         if cb:
@@ -66,7 +67,6 @@ def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interp
                 self.functions[f['name']] = tool
         self.tool_choice = tool_choice
         self.llama_instance = llama_instance or interpreter.llama_instance if interpreter else None
-        #self.tool_end_message = '\nNOTE: The user saw this output, do not repeat it.'
 
 
     async def process_tool_calls(self, tool_calls):
@@ -186,31 +186,37 @@ async def async_response_generator(self, response):
             resp = ModelResponse(stream=True, **item)
             yield resp
 
-    async def attempt_completion(self, stream=True):
-        args = {
-            "temperature": self.temperature,
-            "top_p": self.top_p,
-            "max_tokens": self.max_tokens,
-            "stream": stream,
-        }
-
+    async def attempt_completion(self):
+        stream = self.stream
         if self.llama_instance:
+            args = {
+                "temperature": self.temperature,
+                "top_p": self.top_p,
+                "max_tokens": self.max_tokens,
+                "stream": stream,
+            }
             res = create_chat_completion(self.interpreter, messages=self.messages, tools=[self.tools[0]], **args)
             if args['stream']:
                 return self.async_response_generator(res)
             else:
                 return ModelResponse(**next(res))
-
+        self.logger.debug('chat completion')
         return await acompletion(
             model=self.model,
             messages=self.messages,
             timeout=self.timeout,
-            **args
+            tools=self.tools,
+            tool_choice=self.tool_choice,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            max_tokens=self.max_tokens,
+            stream=stream,
         )
 
-    async def get_completion(self, stream=False):
+    async def get_completion(self):
+        stream = self.stream
         if self.llama_instance:
-            response = await self.attempt_completion(stream=stream)
+            response = await self.attempt_completion()
             if stream:
                 return await self.process_streaming_response(response)
             else:
@@ -220,7 +226,8 @@ async def get_completion(self, stream=False):
 
         for retry_count in range(max_retries):
             try:
-                response = await self.attempt_completion(stream=stream)
+                response = await self.attempt_completion()
+                self.logger.debug(f'chat completion {response}')
                 if stream:
                     return await self.process_streaming_response(response)
                 else:
@@ -236,10 +243,11 @@ async def get_completion(self, stream=False):
                     raise Exception("Max retries reached. Unable to get completion.")
 
 
-    async def achat(self, messages=None, stream=False) -> str:
+    async def achat(self, messages=None) -> str:
         if messages:
             self.messages = messages
-        response = await self.get_completion(stream)
+        response = await self.get_completion()
+        self.logger.debug(f'chat complete')
         return response
 
     def chat(self, **kwargs) -> str:
@@ -289,16 +297,26 @@ def chat(interpreter, **kwargs):
     try:
         signal.signal(signal.SIGINT, signal_handler)
         spinner.start()
-        return loop.run_until_complete(chat_auto.achat(stream=True))
+        return loop.run_until_complete(chat_auto.achat())
     except KeyboardInterrupt:
         builtins.print("\033[91m\nOperation cancelled by user.\033[0m")
         tasks = asyncio.all_tasks(loop=loop)
         for task in tasks:
             task.cancel()
-        loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
+        try:
+            loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
+            loop.run_until_complete(asyncio.sleep(0.1))
+        except asyncio.CancelledError:
+            pass
         return None
     finally:
         signal.signal(signal.SIGINT, original_handler)
         spinner.stop()
-        loop.stop()
-        loop.close()
\ No newline at end of file
+        try:
+            pending = asyncio.all_tasks(loop=loop)
+            for task in pending:
+                task.cancel()
+            loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
+            loop.run_until_complete(loop.shutdown_asyncgens())
+        finally:
+            loop.close()
\ No newline at end of file
diff --git a/r2ai/ui/chat.py b/r2ai/ui/chat.py
index cded705..9e823fa 100644
--- a/r2ai/ui/chat.py
+++ b/r2ai/ui/chat.py
@@ -22,4 +22,4 @@ async def chat(ai, message, cb):
 
     chat_auto = ChatAuto(model, interpreter=ai, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, cb=cb)
 
-    return await chat_auto.achat(stream=True)
+    return await chat_auto.achat()
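
Note (not part of the patch): the cleanup added to chat()'s finally block follows the standard asyncio teardown sequence of cancelling pending tasks, draining them, and shutting down async generators before closing the loop. Below is a minimal standalone sketch of that pattern; the worker() task and timings are placeholders unrelated to r2ai.

    import asyncio

    async def worker():
        # Stands in for an in-flight completion task.
        await asyncio.sleep(10)

    loop = asyncio.new_event_loop()
    loop.create_task(worker())
    try:
        # Stands in for loop.run_until_complete(chat_auto.achat()).
        loop.run_until_complete(asyncio.sleep(0.05))
    finally:
        try:
            # Cancel whatever is still pending, then drain it so no
            # "Task was destroyed but it is pending" warnings are emitted.
            pending = asyncio.all_tasks(loop=loop)
            for task in pending:
                task.cancel()
            # return_exceptions=True absorbs the CancelledError raised by each task.
            loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            loop.close()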