Commit

refactor code
the-praxs committed Aug 17, 2024
1 parent d4ff3db commit e67aecc
Showing 2 changed files with 234 additions and 0 deletions.
18 changes: 18 additions & 0 deletions agentops/llms/__init__.py
@@ -12,6 +12,7 @@
from .litellm import LiteLLMProvider
from .ollama import OllamaProvider
from .openai import OpenAiProvider
from .mistral import MistralProvider

original_func = {}
original_create = None
@@ -35,6 +36,9 @@ class LlmTracker:
"groq": {
"0.9.0": ("Client.chat", "AsyncClient.chat"),
},
"mistralai": {
"1.0.1": ("chat.complete", "chat.stream"),
},
}

def __init__(self, client):
@@ -116,6 +120,17 @@ def override_api(self):
f"Only Groq>=0.9.0 supported. v{module_version} found."
)

if api == "mistralai":
module_version = version(api)

if Version(module_version) >= parse("1.0.1"):
provider = MistralProvider(self.client)
provider.override()
else:
logger.warning(
f"Only MistralAI>=1.0.1 supported. v{module_version} found."
)
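
An aside on the version gate above: it can be exercised in isolation. A minimal sketch, assuming `version` comes from `importlib.metadata` and `Version`/`parse` from `packaging.version`, as the names in the diff suggest (the `meets_minimum` helper is illustrative, not part of the tracker):

from importlib.metadata import version
from packaging.version import Version, parse

def meets_minimum(api: str, min_supported: str) -> bool:
    """True when the installed package satisfies the minimum supported version."""
    module_version = version(api)  # installed version string, e.g. "1.0.2"
    return Version(module_version) >= parse(min_supported)

# meets_minimum("mistralai", "1.0.1") -> True on a supported install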

def stop_instrumenting(self):
openai_provider = OpenAiProvider(self.client)
openai_provider.undo_override()
@@ -125,3 +140,6 @@ def stop_instrumenting(self):

cohere_provider = CohereProvider(self.client)
cohere_provider.undo_override()

mistral_provider = MistralProvider(self.client)
mistral_provider.undo_override()
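
Taken together, the tracker lifecycle for Mistral mirrors the other providers. A hypothetical usage sketch (the `client` object and its construction are assumed; `LlmTracker`, `override_api`, and `stop_instrumenting` all appear in this file):

tracker = LlmTracker(client)   # client: the AgentOps client instance (assumed)
tracker.override_api()         # patches each supported, installed SDK, now including mistralai>=1.0.1
# ... make Mistral chat calls here; each one is recorded as an LLMEvent ...
tracker.stop_instrumenting()   # restores the original SDK methods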
216 changes: 216 additions & 0 deletions agentops/llms/mistral.py
@@ -0,0 +1,216 @@
import inspect
import pprint
import sys
from typing import Optional

from ..event import LLMEvent, ErrorEvent
from ..session import Session
from ..log_config import logger
from agentops.helpers import get_ISO_time, check_call_stack_for_agent_id
from .instrumented_provider import InstrumentedProvider

from mistralai import Chat

class MistralProvider(InstrumentedProvider):

original_complete = None
original_complete_async = None
original_stream = None
original_stream_async = None

def __init__(self, client):
super().__init__(client)
self._provider_name = "Mistral"

def handle_response(
self, response, kwargs, init_timestamp, session: Optional[Session] = None
) -> dict:
"""Handle responses for Mistral"""

self.llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
if session is not None:
self.llm_event.session_id = session.session_id

def handle_stream_chunk(chunk: dict):
# NOTE: prompt/completion usage not returned in response when streaming
# We take the first ChatCompletionChunk and accumulate the deltas from all subsequent chunks to build one full chat completion
            if self.llm_event.returns is None:
self.llm_event.returns = chunk.data

try:
accumulated_delta = self.llm_event.returns.choices[0].delta
self.llm_event.agent_id = check_call_stack_for_agent_id()
self.llm_event.model = chunk.data.model
self.llm_event.prompt = kwargs["messages"]

# NOTE: We assume for completion only choices[0] is relevant
choice = chunk.data.choices[0]

if choice.delta.content:
accumulated_delta.content += choice.delta.content

if choice.delta.role:
accumulated_delta.role = choice.delta.role

if choice.delta.tool_calls:
accumulated_delta.tool_calls = choice.delta.tool_calls

                if choice.finish_reason:
                    # Streaming is done. Record LLMEvent
                    self.llm_event.returns.choices[0].finish_reason = (
                        choice.finish_reason
                    )
                    self.llm_event.completion = {
                        "role": accumulated_delta.role,
                        "content": accumulated_delta.content,
                        "tool_calls": accumulated_delta.tool_calls,
                    }
self.llm_event.prompt_tokens = chunk.data.usage.prompt_tokens
self.llm_event.completion_tokens = chunk.data.usage.completion_tokens
self.llm_event.end_timestamp = get_ISO_time()
self._safe_record(session, self.llm_event)

except Exception as e:
self._safe_record(
session, ErrorEvent(trigger_event=self.llm_event, exception=e)
)

kwargs_str = pprint.pformat(kwargs)
chunk = pprint.pformat(chunk)
logger.warning(
f"Unable to parse a chunk for LLM call. Skipping upload to AgentOps\n"
f"chunk:\n {chunk}\n"
f"kwargs:\n {kwargs_str}\n"
)

# if the response is a generator, decorate the generator
if inspect.isasyncgen(response):

async def async_generator():
async for chunk in response:
handle_stream_chunk(chunk)
yield chunk

return async_generator()

elif inspect.isgenerator(response):

def generator():
for chunk in response:
handle_stream_chunk(chunk)
yield chunk

return generator()

try:
self.llm_event.returns = response
self.llm_event.agent_id = check_call_stack_for_agent_id()
self.llm_event.model = response.model
self.llm_event.prompt = kwargs["messages"]
self.llm_event.prompt_tokens = response.usage.prompt_tokens
self.llm_event.completion = response.choices[0].message.model_dump()
self.llm_event.completion_tokens = response.usage.completion_tokens
self.llm_event.end_timestamp = get_ISO_time()

self._safe_record(session, self.llm_event)
except Exception as e:
self._safe_record(
session, ErrorEvent(trigger_event=self.llm_event, exception=e)
)
kwargs_str = pprint.pformat(kwargs)
response = pprint.pformat(response)
logger.warning(
f"Unable to parse response for LLM call. Skipping upload to AgentOps\n"
f"response:\n {response}\n"
f"kwargs:\n {kwargs_str}\n"
)

return response
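
The streaming branch above folds each chunk's delta into a single message. A stand-alone sketch of that accumulation, using plain dicts rather than Mistral's chunk objects (illustrative only):

def accumulate_deltas(deltas: list) -> dict:
    """Fold streamed deltas into one completion, as handle_stream_chunk does."""
    message = {"role": None, "content": "", "tool_calls": None}
    for delta in deltas:
        if delta.get("role"):
            message["role"] = delta["role"]
        if delta.get("content"):
            message["content"] += delta["content"]
        if delta.get("tool_calls"):
            message["tool_calls"] = delta["tool_calls"]
    return message

assert accumulate_deltas(
    [{"role": "assistant", "content": "Hel"}, {"content": "lo"}]
) == {"role": "assistant", "content": "Hello", "tool_calls": None}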

    def _override_complete(self):
        # Store the original on the instance so undo_override can restore it
        self.original_complete = Chat.complete

        def patched_function(*args, **kwargs):
            # Call the original function with its original arguments
            init_timestamp = get_ISO_time()
            session = kwargs.pop("session", None)
            result = self.original_complete(*args, **kwargs)
            return self.handle_response(result, kwargs, init_timestamp, session=session)

        # Override the original method with the patched one
        Chat.complete = patched_function

    def _override_complete_async(self):
        # Store the original on the instance so undo_override can restore it
        self.original_complete_async = Chat.complete_async

        async def patched_function(*args, **kwargs):
            # Call the original function with its original arguments
            init_timestamp = get_ISO_time()
            session = kwargs.pop("session", None)
            result = await self.original_complete_async(*args, **kwargs)
            return self.handle_response(result, kwargs, init_timestamp, session=session)

        # Override the original method with the patched one
        Chat.complete_async = patched_function

    def _override_stream(self):
        # Store the original on the instance so undo_override can restore it
        self.original_stream = Chat.stream

        def patched_function(*args, **kwargs):
            # Call the original function with its original arguments
            init_timestamp = get_ISO_time()
            session = kwargs.pop("session", None)
            result = self.original_stream(*args, **kwargs)
            return self.handle_response(result, kwargs, init_timestamp, session=session)

        # Override the original method with the patched one
        Chat.stream = patched_function

    def _override_stream_async(self):
        # Store the original on the instance so undo_override can restore it
        self.original_stream_async = Chat.stream_async

        async def patched_function(*args, **kwargs):
            # Call the original function with its original arguments
            init_timestamp = get_ISO_time()
            session = kwargs.pop("session", None)
            result = await self.original_stream_async(*args, **kwargs)
            return self.handle_response(result, kwargs, init_timestamp, session=session)

        # Override the original method with the patched one
        Chat.stream_async = patched_function

def override(self):
self._override_complete()
self._override_complete_async()
self._override_stream()
self._override_stream_async()

def _undo_override_complete(self):
Chat.complete = self.original_complete

def _undo_override_complete_async(self):
Chat.complete_async = self.original_complete_async

def _undo_override_stream(self):
Chat.stream = self.original_stream

def _undo_override_stream_async(self):
Chat.stream_async = self.original_stream_async

def undo_override(self):
self._undo_override_complete()
self._undo_override_complete_async()
self._undo_override_stream()
self._undo_override_stream_async()
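
The provider's save-then-restore pattern works because each override stores the original method on the instance, which is exactly what `undo_override` reads back. A self-contained miniature of the same pattern (the `Greeter` class is invented for illustration):

class Greeter:
    def hello(self):
        return "hello"

class GreeterPatcher:
    original_hello = None

    def override(self):
        self.original_hello = Greeter.hello  # keep a handle for undo
        def patched(instance):
            # call through to the original, then decorate the result
            return self.original_hello(instance).upper()
        Greeter.hello = patched

    def undo_override(self):
        Greeter.hello = self.original_hello  # restore the untouched method

patcher = GreeterPatcher()
patcher.override()
assert Greeter().hello() == "HELLO"
patcher.undo_override()
assert Greeter().hello() == "hello"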
