Commit 7d1eb30 (parent 9f24b7d): showing 4 changed files with 213 additions and 67 deletions.
@@ -1,73 +1,212 @@
The standalone override_litellm_completion(tracker) and override_litellm_async_completion(tracker) functions, which patched litellm.completion and litellm.acompletion and reported results through tracker.handle_response_v1_openai, are replaced by a LiteLLMProvider class:

import pprint
from typing import Optional

from ..log_config import logger
from ..event import LLMEvent, ErrorEvent
from ..session import Session
from agentops.helpers import get_ISO_time, check_call_stack_for_agent_id
from agentops.llms.instrumented_provider import InstrumentedProvider
from agentops.time_travel import fetch_completion_override_from_time_travel_cache


class LiteLLMProvider(InstrumentedProvider):
    def __init__(self, client):
        super().__init__(client)

    def override(self):
        self._override_async_completion()
        self._override_completion()

    def undo_override(self):
        pass

    def handle_response(
        self, response, kwargs, init_timestamp, session: Optional[Session] = None
    ) -> dict:
        """Handle responses for OpenAI versions >v1.0.0"""
        from openai import AsyncStream, Stream
        from openai.resources import AsyncCompletions
        from openai.types.chat import ChatCompletionChunk

        self.llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
        if session is not None:
            self.llm_event.session_id = session.session_id

        def handle_stream_chunk(chunk: ChatCompletionChunk):
            # NOTE: prompt/completion usage not returned in response when streaming
            # We take the first ChatCompletionChunk and accumulate the deltas from all
            # subsequent chunks to build one full chat completion
            if self.llm_event.returns == None:
                self.llm_event.returns = chunk

            try:
                accumulated_delta = self.llm_event.returns.choices[0].delta
                self.llm_event.agent_id = check_call_stack_for_agent_id()
                self.llm_event.model = chunk.model
                self.llm_event.prompt = kwargs["messages"]

                # NOTE: We assume for completion only choices[0] is relevant
                choice = chunk.choices[0]

                if choice.delta.content:
                    accumulated_delta.content += choice.delta.content

                if choice.delta.role:
                    accumulated_delta.role = choice.delta.role

                if choice.delta.tool_calls:
                    accumulated_delta.tool_calls = choice.delta.tool_calls

                if choice.delta.function_call:
                    accumulated_delta.function_call = choice.delta.function_call

                if choice.finish_reason:
                    # Streaming is done. Record LLMEvent
                    self.llm_event.returns.choices[0].finish_reason = (
                        choice.finish_reason
                    )
                    self.llm_event.completion = {
                        "role": accumulated_delta.role,
                        "content": accumulated_delta.content,
                        "function_call": accumulated_delta.function_call,
                        "tool_calls": accumulated_delta.tool_calls,
                    }
                    self.llm_event.end_timestamp = get_ISO_time()

                    self._safe_record(session, self.llm_event)
            except Exception as e:
                self._safe_record(
                    session, ErrorEvent(trigger_event=self.llm_event, exception=e)
                )

                kwargs_str = pprint.pformat(kwargs)
                chunk = pprint.pformat(chunk)
                logger.warning(
                    f"Unable to parse a chunk for LLM call. Skipping upload to AgentOps\n"
                    f"chunk:\n {chunk}\n"
                    f"kwargs:\n {kwargs_str}\n"
                )

        # if the response is a generator, decorate the generator
        if isinstance(response, Stream):

            def generator():
                for chunk in response:
                    handle_stream_chunk(chunk)
                    yield chunk

            return generator()

        # For asynchronous AsyncStream
        elif isinstance(response, AsyncStream):

            async def async_generator():
                async for chunk in response:
                    handle_stream_chunk(chunk)
                    yield chunk

            return async_generator()

        # For async AsyncCompletion
        elif isinstance(response, AsyncCompletions):

            async def async_generator():
                async for chunk in response:
                    handle_stream_chunk(chunk)
                    yield chunk

            return async_generator()

        # v1.0.0+ responses are objects
        try:
            self.llm_event.returns = response
            self.llm_event.agent_id = check_call_stack_for_agent_id()
            self.llm_event.prompt = kwargs["messages"]
            self.llm_event.prompt_tokens = response.usage.prompt_tokens
            self.llm_event.completion = response.choices[0].message.model_dump()
            self.llm_event.completion_tokens = response.usage.completion_tokens
            self.llm_event.model = response.model

            self._safe_record(session, self.llm_event)
        except Exception as e:
            self._safe_record(
                session, ErrorEvent(trigger_event=self.llm_event, exception=e)
            )

            kwargs_str = pprint.pformat(kwargs)
            response = pprint.pformat(response)
            logger.warning(
                f"Unable to parse response for LLM call. Skipping upload to AgentOps\n"
                f"response:\n {response}\n"
                f"kwargs:\n {kwargs_str}\n"
            )

        return response

    def _override_completion(self):
        import litellm
        from openai.types.chat import (
            ChatCompletion,
        )  # Note: litellm calls all LLM APIs using the OpenAI format

        original_create = litellm.completion

        def patched_function(*args, **kwargs):
            init_timestamp = get_ISO_time()

            session = kwargs.get("session", None)
            if "session" in kwargs.keys():
                del kwargs["session"]

            completion_override = fetch_completion_override_from_time_travel_cache(
                kwargs
            )
            if completion_override:
                result_model = ChatCompletion.model_validate_json(completion_override)
                return self.handle_response(
                    result_model, kwargs, init_timestamp, session=session
                )

            # prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs)
            # if prompt_override:
            #     kwargs["messages"] = prompt_override["messages"]

            # Call the original function with its original arguments
            result = original_create(*args, **kwargs)
            return self.handle_response(result, kwargs, init_timestamp, session=session)

        litellm.completion = patched_function

    def _override_async_completion(self):
        import litellm
        from openai.types.chat import (
            ChatCompletion,
        )  # Note: litellm calls all LLM APIs using the OpenAI format

        original_create_async = litellm.acompletion

        async def patched_function(*args, **kwargs):
            init_timestamp = get_ISO_time()

            session = kwargs.get("session", None)
            if "session" in kwargs.keys():
                del kwargs["session"]

            completion_override = fetch_completion_override_from_time_travel_cache(
                kwargs
            )
            if completion_override:
                result_model = ChatCompletion.model_validate_json(completion_override)
                return self.handle_response(
                    result_model, kwargs, init_timestamp, session=session
                )

            # prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs)
            # if prompt_override:
            #     kwargs["messages"] = prompt_override["messages"]

            # Call the original function with its original arguments
            result = await original_create_async(*args, **kwargs)
            return self.handle_response(result, kwargs, init_timestamp, session=session)

        # Override the original method with the patched one
        litellm.acompletion = patched_function
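
For orientation, here is a minimal sketch of how this provider could be wired up by hand. The diff itself only defines __init__(client), override(), and undo_override(); the import path and the choice of passing the litellm module as the client are assumptions for illustration, not something this commit specifies.

import litellm

from agentops.llms.litellm import LiteLLMProvider  # assumed module path; not shown in this diff

# Patch litellm.completion / litellm.acompletion so calls are recorded as LLMEvents.
provider = LiteLLMProvider(client=litellm)  # passing the litellm module as the client is an assumption
provider.override()

response = litellm.completion(
    model="gpt-3.5-turbo", messages=[{"content": "ping", "role": "user"}]
)

# undo_override() is a no-op in this commit, so the patch stays in place after this call.
provider.undo_override()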
@@ -0,0 +1,16 @@
import agentops
from dotenv import load_dotenv
import litellm

load_dotenv()
agentops.init(default_tags=["litellm-provider-test"])

response = litellm.completion(
    model="gpt-3.5-turbo", messages=[{"content": "Hello, how are you?", "role": "user"}]
)

agentops.end_session(end_state="Success")

###
# Used to verify that one session is created with one LLM event
###
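
A streaming variant of the same check would target the handle_stream_chunk path. The sketch below is not part of this commit; it assumes litellm's stream=True yields OpenAI-style ChatCompletionChunk objects, and whether litellm's stream wrapper passes the Stream/AsyncStream isinstance checks in handle_response is not settled by this diff.

import agentops
from dotenv import load_dotenv
import litellm

load_dotenv()
agentops.init(default_tags=["litellm-provider-stream-test"])  # illustrative tag name

# Iterating the stream consumes the chunk deltas; the provider's streaming branch is
# intended to accumulate them into one completion and record an LLMEvent once a
# finish_reason is seen.
stream = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"content": "Hello, how are you?", "role": "user"}],
    stream=True,
)
for chunk in stream:
    pass

agentops.end_session(end_state="Success")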