diff --git a/agentops/llms/llama_stack_client.py b/agentops/llms/llama_stack_client.py
index 8218fb54b..b9ec8bd7b 100644
--- a/agentops/llms/llama_stack_client.py
+++ b/agentops/llms/llama_stack_client.py
@@ -23,7 +23,7 @@ def __init__(self, client):
     def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None) -> dict:
         """Handle responses for LlamaStack"""
         from llama_stack_client import LlamaStackClient
-
+
         llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
         if session is not None:
             llm_event.session_id = session.session_id
@@ -47,8 +47,9 @@ def handle_stream_chunk(chunk: dict):
                     llm_event.returns.delta += choice.delta
 
                 if choice.event_type == "complete":
-
-                    llm_event.prompt = [{ "content": message.content, "role": message.role } for message in kwargs["messages"]]
+                    llm_event.prompt = [
+                        {"content": message.content, "role": message.role} for message in kwargs["messages"]
+                    ]
                     llm_event.agent_id = check_call_stack_for_agent_id()
                     llm_event.completion = accumulated_delta
                     llm_event.prompt_tokens = None
@@ -88,9 +89,9 @@ async def async_generator():
 
         try:
             llm_event.returns = response
-            llm_event.agent_id = check_call_stack_for_agent_id() 
+            llm_event.agent_id = check_call_stack_for_agent_id()
             llm_event.model = kwargs["model_id"]
-            llm_event.prompt = [{ "content": message.content, "role": message.role } for message in kwargs["messages"]]
+            llm_event.prompt = [{"content": message.content, "role": message.role} for message in kwargs["messages"]]
             llm_event.prompt_tokens = None
             llm_event.completion = response.completion_message.content
             llm_event.completion_tokens = None
@@ -134,9 +135,7 @@ def override(self):
         # self._override_stream_async()
 
     def undo_override(self):
-        if (
-            self.original_complete is not None
-        ):
-
+        if self.original_complete is not None:
             from llama_stack_client.resources import InferenceResource
+
             InferenceResource.chat_completion = self.original_complete
diff --git a/tests/core_manual_tests/providers/llama_stack_client_canary.py b/tests/core_manual_tests/providers/llama_stack_client_canary.py
index f61ac3473..0955f9ccc 100644
--- a/tests/core_manual_tests/providers/llama_stack_client_canary.py
+++ b/tests/core_manual_tests/providers/llama_stack_client_canary.py
@@ -11,8 +11,8 @@
 
 agentops.init(default_tags=["llama-stack-client-provider-test"])
 
-host = "0.0.0.0" # LLAMA_STACK_HOST
-port = 5001 # LLAMA_STACK_PORT
+host = "0.0.0.0"  # LLAMA_STACK_HOST
+port = 5001  # LLAMA_STACK_PORT
 
 full_host = f"http://{host}:{port}"
 
@@ -28,26 +28,28 @@
         ),
     ],
     model_id="meta-llama/Llama-3.2-3B-Instruct",
-    stream=False
+    stream=False,
 )
 
+
 async def stream_test():
-  response = client.inference.chat_completion(
-      messages=[
-          UserMessage(
-              content="hello world, write me a 3 word poem about the moon",
-              role="user",
-          ),
-      ],
-      model_id="meta-llama/Llama-3.2-3B-Instruct",
-      stream=True
-  )
+    response = client.inference.chat_completion(
+        messages=[
+            UserMessage(
+                content="hello world, write me a 3 word poem about the moon",
+                role="user",
+            ),
+        ],
+        model_id="meta-llama/Llama-3.2-3B-Instruct",
+        stream=True,
+    )
 
-  async for log in EventLogger().log(response):
-      log.print()
+    async for log in EventLogger().log(response):
+        log.print()
 
 
 async def main():
     await stream_test()
 
+
 agentops.end_session(end_state="Success")