diff --git a/agentops/client.py b/agentops/client.py
index 1add700f..9ee587f4 100644
--- a/agentops/client.py
+++ b/agentops/client.py
@@ -448,10 +448,3 @@ def api_key(self):
     @property
     def parent_key(self):
         return self._config.parent_key
-
-
-def safe_record(session, event):
-    if session is not None:
-        session.record(event)
-    else:
-        Client().record(event)
diff --git a/agentops/llms/__init__.py b/agentops/llms/__init__.py
index fa7df2c9..f2615c3e 100644
--- a/agentops/llms/__init__.py
+++ b/agentops/llms/__init__.py
@@ -6,12 +6,10 @@
 from packaging.version import Version, parse

 from ..log_config import logger
-from ..helpers import get_ISO_time
-import inspect

 from .cohere import CohereProvider
 from .groq import GroqProvider
-from .litellm import override_litellm_completion, override_litellm_async_completion
+from .litellm import LiteLLMProvider
 from .ollama import OllamaProvider
 from .openai import OpenAiProvider

@@ -59,8 +57,8 @@ def override_api(self):
                         )

                 if Version(module_version) >= parse("1.3.1"):
-                    override_litellm_completion(self)
-                    override_litellm_async_completion(self)
+                    provider = LiteLLMProvider(self.client)
+                    provider.override()
                 else:
                     logger.warning(
                         f"Only LiteLLM>=1.3.1 supported. v{module_version} found."
diff --git a/agentops/llms/litellm.py b/agentops/llms/litellm.py
index c22b2a1d..8fa31255 100644
--- a/agentops/llms/litellm.py
+++ b/agentops/llms/litellm.py
@@ -1,73 +1,212 @@
-from agentops.helpers import get_ISO_time
+import pprint
+from typing import Optional
+
+from ..log_config import logger
+from ..event import LLMEvent, ErrorEvent
+from ..session import Session
+from agentops.helpers import get_ISO_time, check_call_stack_for_agent_id
+from agentops.llms.instrumented_provider import InstrumentedProvider
 from agentops.time_travel import fetch_completion_override_from_time_travel_cache


-def override_litellm_completion(tracker):
-    import litellm
-    from openai.types.chat import (
-        ChatCompletion,
-    )  # Note: litellm calls all LLM APIs using the OpenAI format
-
-    original_create = litellm.completion
+class LiteLLMProvider(InstrumentedProvider):
+    def __init__(self, client):
+        super().__init__(client)
+
+    def override(self):
+        self._override_async_completion()
+        self._override_completion()
+
+    def undo_override(self):
+        pass
+
+    def handle_response(
+        self, response, kwargs, init_timestamp, session: Optional[Session] = None
+    ) -> dict:
+        """Handle responses for OpenAI versions >v1.0.0"""
+        from openai import AsyncStream, Stream
+        from openai.resources import AsyncCompletions
+        from openai.types.chat import ChatCompletionChunk
+
+        self.llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
+        if session is not None:
+            self.llm_event.session_id = session.session_id
+
+        def handle_stream_chunk(chunk: ChatCompletionChunk):
+            # NOTE: prompt/completion usage not returned in response when streaming
+            # We take the first ChatCompletionChunk and accumulate the deltas from all subsequent chunks to build one full chat completion
+            if self.llm_event.returns == None:
+                self.llm_event.returns = chunk
+
+            try:
+                accumulated_delta = self.llm_event.returns.choices[0].delta
+                self.llm_event.agent_id = check_call_stack_for_agent_id()
+                self.llm_event.model = chunk.model
+                self.llm_event.prompt = kwargs["messages"]
+
+                # NOTE: We assume for completion only choices[0] is relevant
+                choice = chunk.choices[0]
+
+                if choice.delta.content:
+                    accumulated_delta.content += choice.delta.content
+
+                if choice.delta.role:
+                    accumulated_delta.role = choice.delta.role
+
+                if choice.delta.tool_calls:
+                    accumulated_delta.tool_calls = choice.delta.tool_calls
+
+                if choice.delta.function_call:
+                    accumulated_delta.function_call = choice.delta.function_call
+
+                if choice.finish_reason:
+                    # Streaming is done. Record LLMEvent
+                    self.llm_event.returns.choices[0].finish_reason = (
+                        choice.finish_reason
+                    )
+                    self.llm_event.completion = {
+                        "role": accumulated_delta.role,
+                        "content": accumulated_delta.content,
+                        "function_call": accumulated_delta.function_call,
+                        "tool_calls": accumulated_delta.tool_calls,
+                    }
+                    self.llm_event.end_timestamp = get_ISO_time()
+
+                    self._safe_record(session, self.llm_event)
+            except Exception as e:
+                self._safe_record(
+                    session, ErrorEvent(trigger_event=self.llm_event, exception=e)
+                )
+
+                kwargs_str = pprint.pformat(kwargs)
+                chunk = pprint.pformat(chunk)
+                logger.warning(
+                    f"Unable to parse a chunk for LLM call. Skipping upload to AgentOps\n"
+                    f"chunk:\n {chunk}\n"
+                    f"kwargs:\n {kwargs_str}\n"
+                )
+
+        # if the response is a generator, decorate the generator
+        if isinstance(response, Stream):
+
+            def generator():
+                for chunk in response:
+                    handle_stream_chunk(chunk)
+                    yield chunk
+
+            return generator()
+
+        # For asynchronous AsyncStream
+        elif isinstance(response, AsyncStream):
+
+            async def async_generator():
+                async for chunk in response:
+                    handle_stream_chunk(chunk)
+                    yield chunk
+
+            return async_generator()
+
+        # For async AsyncCompletion
+        elif isinstance(response, AsyncCompletions):
+
+            async def async_generator():
+                async for chunk in response:
+                    handle_stream_chunk(chunk)
+                    yield chunk
+
+            return async_generator()
+
+        # v1.0.0+ responses are objects
+        try:
+            self.llm_event.returns = response
+            self.llm_event.agent_id = check_call_stack_for_agent_id()
+            self.llm_event.prompt = kwargs["messages"]
+            self.llm_event.prompt_tokens = response.usage.prompt_tokens
+            self.llm_event.completion = response.choices[0].message.model_dump()
+            self.llm_event.completion_tokens = response.usage.completion_tokens
+            self.llm_event.model = response.model
+
+            self._safe_record(session, self.llm_event)
+        except Exception as e:
+            self._safe_record(
+                session, ErrorEvent(trigger_event=self.llm_event, exception=e)
+            )

-    def patched_function(*args, **kwargs):
-        init_timestamp = get_ISO_time()
+            kwargs_str = pprint.pformat(kwargs)
+            response = pprint.pformat(response)
+            logger.warning(
+                f"Unable to parse response for LLM call. Skipping upload to AgentOps\n"
+                f"response:\n {response}\n"
+                f"kwargs:\n {kwargs_str}\n"
+            )

-        session = kwargs.get("session", None)
-        if "session" in kwargs.keys():
-            del kwargs["session"]
+        return response

-        completion_override = fetch_completion_override_from_time_travel_cache(kwargs)
-        if completion_override:
-            result_model = ChatCompletion.model_validate_json(completion_override)
-            return tracker.handle_response_v1_openai(
-                tracker, result_model, kwargs, init_timestamp, session=session
-            )
+    def _override_completion(self):
+        import litellm
+        from openai.types.chat import (
+            ChatCompletion,
+        )  # Note: litellm calls all LLM APIs using the OpenAI format

-        # prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs)
-        # if prompt_override:
-        #     kwargs["messages"] = prompt_override["messages"]
+        original_create = litellm.completion

-        # Call the original function with its original arguments
-        result = original_create(*args, **kwargs)
-        return tracker.handle_response_v1_openai(
-            tracker, result, kwargs, init_timestamp, session=session
-        )
+        def patched_function(*args, **kwargs):
+            init_timestamp = get_ISO_time()

-    litellm.completion = patched_function
+            session = kwargs.get("session", None)
+            if "session" in kwargs.keys():
+                del kwargs["session"]
+
+            completion_override = fetch_completion_override_from_time_travel_cache(
+                kwargs
+            )
+            if completion_override:
+                result_model = ChatCompletion.model_validate_json(completion_override)
+                return self.handle_response(
+                    result_model, kwargs, init_timestamp, session=session
+                )

-
-def override_litellm_async_completion(tracker):
-    import litellm
-    from openai.types.chat import (
-        ChatCompletion,
-    )  # Note: litellm calls all LLM APIs using the OpenAI format
+            # prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs)
+            # if prompt_override:
+            #     kwargs["messages"] = prompt_override["messages"]

-    original_create_async = litellm.acompletion
+            # Call the original function with its original arguments
+            result = original_create(*args, **kwargs)
+            return self.handle_response(result, kwargs, init_timestamp, session=session)

-    async def patched_function(*args, **kwargs):
-        init_timestamp = get_ISO_time()
+        litellm.completion = patched_function

-        session = kwargs.get("session", None)
-        if "session" in kwargs.keys():
-            del kwargs["session"]
+    def _override_async_completion(self):
+        import litellm
+        from openai.types.chat import (
+            ChatCompletion,
+        )  # Note: litellm calls all LLM APIs using the OpenAI format

-        completion_override = fetch_completion_override_from_time_travel_cache(kwargs)
-        if completion_override:
-            result_model = ChatCompletion.model_validate_json(completion_override)
-            return tracker.handle_response_v1_openai(
-                tracker, result_model, kwargs, init_timestamp, session=session
-            )
+        original_create_async = litellm.acompletion

-        # prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs)
-        # if prompt_override:
-        #     kwargs["messages"] = prompt_override["messages"]
+        async def patched_function(*args, **kwargs):
+            init_timestamp = get_ISO_time()

-        # Call the original function with its original arguments
-        result = await original_create_async(*args, **kwargs)
-        return tracker.handle_response_v1_openai(
-            tracker, result, kwargs, init_timestamp, session=session
-        )
+            session = kwargs.get("session", None)
+            if "session" in kwargs.keys():
+                del kwargs["session"]

-    # Override the original method with the patched one
-    litellm.acompletion = patched_function
+            completion_override = fetch_completion_override_from_time_travel_cache(
+                kwargs
+            )
+            if completion_override:
+                result_model = ChatCompletion.model_validate_json(completion_override)
+                return self.handle_response(
+                    result_model, kwargs, init_timestamp, session=session
+                )
+
+            # prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs)
+            # if prompt_override:
+            #     kwargs["messages"] = prompt_override["messages"]
+
+            # Call the original function with its original arguments
+            result = await original_create_async(*args, **kwargs)
+            return self.handle_response(result, kwargs, init_timestamp, session=session)
+
+        # Override the original method with the patched one
+        litellm.acompletion = patched_function
diff --git a/tests/core_manual_tests/providers/litellm_canary.py b/tests/core_manual_tests/providers/litellm_canary.py
new file mode 100644
index 00000000..b97ce646
--- /dev/null
+++ b/tests/core_manual_tests/providers/litellm_canary.py
@@ -0,0 +1,16 @@
+import agentops
+from dotenv import load_dotenv
+import litellm
+
+load_dotenv()
+agentops.init(default_tags=["litellm-provider-test"])
+
+response = litellm.completion(
+    model="gpt-3.5-turbo", messages=[{"content": "Hello, how are you?", "role": "user"}]
+)
+
+agentops.end_session(end_state="Success")
+
+###
+# Used to verify that one session is created with one LLM event
+###
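Reviewer note: the module-level safe_record() removed from agentops/client.py is superseded by the _safe_record() calls in the new provider, inherited from InstrumentedProvider. That base class is not part of this diff; the sketch below is only an assumption of its shape, inferred from the removed helper and the calls above, not code from this PR.

    # Assumed sketch of the inherited helper (not part of this diff).
    class InstrumentedProvider:
        def __init__(self, client):
            # Keeps the AgentOps client handed in by LlmTracker
            # (see `provider = LiteLLMProvider(self.client)` above).
            self.client = client

        def _safe_record(self, session, event):
            # Prefer the per-call session popped from kwargs; otherwise fall back
            # to the client, mirroring the removed client.safe_record() helper.
            if session is not None:
                session.record(event)
            else:
                self.client.record(event)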