cohere support
bboynton97 committed Aug 14, 2024
1 parent 9f24b7d commit 7d1eb30
Showing 4 changed files with 213 additions and 67 deletions.
7 changes: 0 additions & 7 deletions agentops/client.py
@@ -448,10 +448,3 @@ def api_key(self):
@property
def parent_key(self):
return self._config.parent_key


def safe_record(session, event):
if session is not None:
session.record(event)
else:
Client().record(event)
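The module-level safe_record helper deleted above is replaced by the _safe_record calls made on the provider instances in the litellm.py diff below. The base class is not part of this commit, so the following is only a minimal sketch of what InstrumentedProvider._safe_record presumably does, assuming it ports the removed helper unchanged and that the client handed to the provider is the AgentOps Client:

class InstrumentedProvider:  # hypothetical sketch; the real base class is not shown in this diff
    def __init__(self, client):
        self.client = client  # AgentOps Client passed in when the provider is constructed

    def _safe_record(self, session, event):
        # Prefer the session threaded through the patched completion call;
        # fall back to the client when no session was supplied (assumption).
        if session is not None:
            session.record(event)
        else:
            self.client.record(event)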
8 changes: 3 additions & 5 deletions agentops/llms/__init__.py
@@ -6,12 +6,10 @@
from packaging.version import Version, parse

from ..log_config import logger
from ..helpers import get_ISO_time
import inspect

from .cohere import CohereProvider
from .groq import GroqProvider
from .litellm import override_litellm_completion, override_litellm_async_completion
from .litellm import LiteLLMProvider
from .ollama import OllamaProvider
from .openai import OpenAiProvider

@@ -59,8 +57,8 @@ def override_api(self):
)

if Version(module_version) >= parse("1.3.1"):
override_litellm_completion(self)
override_litellm_async_completion(self)
provider = LiteLLMProvider(self.client)
provider.override()
else:
logger.warning(
f"Only LiteLLM>=1.3.1 supported. v{module_version} found."
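The version check above now hands instrumentation off to a provider object instead of module-level override functions. The newly imported CohereProvider presumably follows the same pattern; a sketch of what the corresponding branch of override_api might look like (the cohere provider file is not shown in this commit, module_version is assumed to come from the same version lookup used for the litellm branch, and the minimum version is an assumption):

if api == "cohere":  # hypothetical branch; not part of the shown diff
    module_version = version(api)  # assumed: same importlib.metadata-style helper as above
    if Version(module_version) >= parse("5.4.0"):  # version floor is an assumption
        provider = CohereProvider(self.client)
        provider.override()
    else:
        logger.warning(f"Only Cohere>=5.4.0 supported. v{module_version} found.")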
249 changes: 194 additions & 55 deletions agentops/llms/litellm.py
@@ -1,73 +1,212 @@
from agentops.helpers import get_ISO_time
import pprint
from typing import Optional

from ..log_config import logger
from ..event import LLMEvent, ErrorEvent
from ..session import Session
from agentops.helpers import get_ISO_time, check_call_stack_for_agent_id
from agentops.llms.instrumented_provider import InstrumentedProvider
from agentops.time_travel import fetch_completion_override_from_time_travel_cache


def override_litellm_completion(tracker):
import litellm
from openai.types.chat import (
ChatCompletion,
) # Note: litellm calls all LLM APIs using the OpenAI format

original_create = litellm.completion
class LiteLLMProvider(InstrumentedProvider):
def __init__(self, client):
super().__init__(client)

def override(self):
self._override_async_completion()
self._override_completion()

def undo_override(self):
pass

def handle_response(
self, response, kwargs, init_timestamp, session: Optional[Session] = None
) -> dict:
"""Handle responses for OpenAI versions >v1.0.0"""
from openai import AsyncStream, Stream
from openai.resources import AsyncCompletions
from openai.types.chat import ChatCompletionChunk

self.llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
if session is not None:
self.llm_event.session_id = session.session_id

def handle_stream_chunk(chunk: ChatCompletionChunk):
# NOTE: prompt/completion usage not returned in response when streaming
# We take the first ChatCompletionChunk and accumulate the deltas from all subsequent chunks to build one full chat completion
if self.llm_event.returns == None:
self.llm_event.returns = chunk

try:
accumulated_delta = self.llm_event.returns.choices[0].delta
self.llm_event.agent_id = check_call_stack_for_agent_id()
self.llm_event.model = chunk.model
self.llm_event.prompt = kwargs["messages"]

# NOTE: We assume for completion only choices[0] is relevant
choice = chunk.choices[0]

if choice.delta.content:
accumulated_delta.content += choice.delta.content

if choice.delta.role:
accumulated_delta.role = choice.delta.role

if choice.delta.tool_calls:
accumulated_delta.tool_calls = choice.delta.tool_calls

if choice.delta.function_call:
accumulated_delta.function_call = choice.delta.function_call

if choice.finish_reason:
# Streaming is done. Record LLMEvent
self.llm_event.returns.choices[0].finish_reason = (
choice.finish_reason
)
self.llm_event.completion = {
"role": accumulated_delta.role,
"content": accumulated_delta.content,
"function_call": accumulated_delta.function_call,
"tool_calls": accumulated_delta.tool_calls,
}
self.llm_event.end_timestamp = get_ISO_time()

self._safe_record(session, self.llm_event)
except Exception as e:
self._safe_record(
session, ErrorEvent(trigger_event=self.llm_event, exception=e)
)

kwargs_str = pprint.pformat(kwargs)
chunk = pprint.pformat(chunk)
logger.warning(
f"Unable to parse a chunk for LLM call. Skipping upload to AgentOps\n"
f"chunk:\n {chunk}\n"
f"kwargs:\n {kwargs_str}\n"
)

# if the response is a generator, decorate the generator
if isinstance(response, Stream):

def generator():
for chunk in response:
handle_stream_chunk(chunk)
yield chunk

return generator()

# For asynchronous AsyncStream
elif isinstance(response, AsyncStream):

async def async_generator():
async for chunk in response:
handle_stream_chunk(chunk)
yield chunk

return async_generator()

# For async AsyncCompletion
elif isinstance(response, AsyncCompletions):

async def async_generator():
async for chunk in response:
handle_stream_chunk(chunk)
yield chunk

return async_generator()

# v1.0.0+ responses are objects
try:
self.llm_event.returns = response
self.llm_event.agent_id = check_call_stack_for_agent_id()
self.llm_event.prompt = kwargs["messages"]
self.llm_event.prompt_tokens = response.usage.prompt_tokens
self.llm_event.completion = response.choices[0].message.model_dump()
self.llm_event.completion_tokens = response.usage.completion_tokens
self.llm_event.model = response.model

self._safe_record(session, self.llm_event)
except Exception as e:
self._safe_record(
session, ErrorEvent(trigger_event=self.llm_event, exception=e)
)

def patched_function(*args, **kwargs):
init_timestamp = get_ISO_time()
kwargs_str = pprint.pformat(kwargs)
response = pprint.pformat(response)
logger.warning(
f"Unable to parse response for LLM call. Skipping upload to AgentOps\n"
f"response:\n {response}\n"
f"kwargs:\n {kwargs_str}\n"
)

session = kwargs.get("session", None)
if "session" in kwargs.keys():
del kwargs["session"]
return response

completion_override = fetch_completion_override_from_time_travel_cache(kwargs)
if completion_override:
result_model = ChatCompletion.model_validate_json(completion_override)
return tracker.handle_response_v1_openai(
tracker, result_model, kwargs, init_timestamp, session=session
)
def _override_completion(self):
import litellm
from openai.types.chat import (
ChatCompletion,
) # Note: litellm calls all LLM APIs using the OpenAI format

# prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs)
# if prompt_override:
# kwargs["messages"] = prompt_override["messages"]
original_create = litellm.completion

# Call the original function with its original arguments
result = original_create(*args, **kwargs)
return tracker.handle_response_v1_openai(
tracker, result, kwargs, init_timestamp, session=session
)
def patched_function(*args, **kwargs):
init_timestamp = get_ISO_time()

litellm.completion = patched_function
session = kwargs.get("session", None)
if "session" in kwargs.keys():
del kwargs["session"]

completion_override = fetch_completion_override_from_time_travel_cache(
kwargs
)
if completion_override:
result_model = ChatCompletion.model_validate_json(completion_override)
return self.handle_response(
result_model, kwargs, init_timestamp, session=session
)

def override_litellm_async_completion(tracker):
import litellm
from openai.types.chat import (
ChatCompletion,
) # Note: litellm calls all LLM APIs using the OpenAI format
# prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs)
# if prompt_override:
# kwargs["messages"] = prompt_override["messages"]

original_create_async = litellm.acompletion
# Call the original function with its original arguments
result = original_create(*args, **kwargs)
return self.handle_response(result, kwargs, init_timestamp, session=session)

async def patched_function(*args, **kwargs):
init_timestamp = get_ISO_time()
litellm.completion = patched_function

session = kwargs.get("session", None)
if "session" in kwargs.keys():
del kwargs["session"]
def _override_async_completion(self):
import litellm
from openai.types.chat import (
ChatCompletion,
) # Note: litellm calls all LLM APIs using the OpenAI format

completion_override = fetch_completion_override_from_time_travel_cache(kwargs)
if completion_override:
result_model = ChatCompletion.model_validate_json(completion_override)
return tracker.handle_response_v1_openai(
tracker, result_model, kwargs, init_timestamp, session=session
)
original_create_async = litellm.acompletion

# prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs)
# if prompt_override:
# kwargs["messages"] = prompt_override["messages"]
async def patched_function(*args, **kwargs):
init_timestamp = get_ISO_time()

# Call the original function with its original arguments
result = await original_create_async(*args, **kwargs)
return tracker.handle_response_v1_openai(
tracker, result, kwargs, init_timestamp, session=session
)
session = kwargs.get("session", None)
if "session" in kwargs.keys():
del kwargs["session"]

# Override the original method with the patched one
litellm.acompletion = patched_function
completion_override = fetch_completion_override_from_time_travel_cache(
kwargs
)
if completion_override:
result_model = ChatCompletion.model_validate_json(completion_override)
return self.handle_response(
result_model, kwargs, init_timestamp, session=session
)

# prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs)
# if prompt_override:
# kwargs["messages"] = prompt_override["messages"]

# Call the original function with its original arguments
result = await original_create_async(*args, **kwargs)
return self.handle_response(result, kwargs, init_timestamp, session=session)

# Override the original method with the patched one
litellm.acompletion = patched_function
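Both patched functions pop an optional session keyword argument before delegating to LiteLLM, which lets a caller attribute a single completion to a specific AgentOps session. A usage sketch under that assumption (the start_session/end_session names follow the public agentops API; adjust to your setup):

import agentops
import litellm

agentops.init(auto_start_session=False)  # assumes manual session management is wanted
session = agentops.start_session(tags=["litellm-session-test"])

# "session" is removed from kwargs by patched_function before litellm sees it,
# so LiteLLM itself never receives the extra keyword.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hello"}],
    session=session,
)

session.end_session(end_state="Success")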
16 changes: 16 additions & 0 deletions tests/core_manual_tests/providers/litellm_canary.py
@@ -0,0 +1,16 @@
import agentops
from dotenv import load_dotenv
import litellm

load_dotenv()
agentops.init(default_tags=["litellm-provider-test"])

response = litellm.completion(
model="gpt-3.5-turbo", messages=[{"content": "Hello, how are you?", "role": "user"}]
)

agentops.end_session(end_state="Success")

###
# Used to verify that one session is created with one LLM event
###
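The streaming path in handle_response above (handle_stream_chunk accumulating deltas into one completion) is not exercised by this canary. A streaming variant might look like the following sketch, assuming LiteLLM yields OpenAI-style chunks when stream=True, as the Note comments in the diff state:

stream = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"content": "Count to five.", "role": "user"}],
    stream=True,
)
# If the returned stream is an OpenAI-style Stream, each chunk is routed through
# handle_stream_chunk and the LLMEvent is recorded once a finish_reason arrives.
for _chunk in stream:
    pass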
