[FEATURE] Add support for AI21 via the ai21 Python SDK #382

Merged
merged 27 commits into main from ai21-ops
Nov 5, 2024

Changes from all commits
27 commits
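This PR wires AI21's chat completions and Contextual Answers endpoints into AgentOps' automatic instrumentation. A minimal sketch of what the feature enables, assuming AGENTOPS_API_KEY and AI21_API_KEY are set in the environment (the model name is illustrative):

import agentops
from ai21 import AI21Client
from ai21.models.chat import ChatMessage

# agentops.init() patches the installed ai21 SDK (>=2.0.0), so the call
# below is recorded as an LLMEvent without any per-call changes
agentops.init()

client = AI21Client()
response = client.chat.completions.create(
    model="jamba-1.5-mini",  # illustrative model name
    messages=[ChatMessage(role="user", content="Say hello in one sentence.")],
)
print(response.choices[0].message.content)

agentops.end_session("Success")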
3c2b4a9
add initial files for support
the-praxs Sep 10, 2024
0a6fd15
Merge branch 'main' into ai21-ops
the-praxs Sep 18, 2024
85ee097
working sync client
the-praxs Sep 19, 2024
765a0fa
stream not working
the-praxs Sep 19, 2024
dde2581
updated examples notebook for current testing
the-praxs Sep 19, 2024
a31ab24
fix for `delta.content` and cleanup
the-praxs Sep 19, 2024
bddc1e5
Merge branch 'main' into ai21-ops
the-praxs Sep 19, 2024
4626505
cleanup again
the-praxs Sep 19, 2024
507e32f
Merge branch 'main' into ai21-ops
the-praxs Sep 20, 2024
0801727
cleanup and add tool event
the-praxs Sep 20, 2024
bc40171
structure examples notebook
the-praxs Sep 20, 2024
889bb1a
Merge branch 'main' into ai21-ops
the-praxs Sep 22, 2024
9100dfa
add contextual answers tracking
the-praxs Sep 22, 2024
857f7bb
cleanup example notebook
the-praxs Sep 22, 2024
b687523
create testing file
the-praxs Sep 22, 2024
6995157
clean example notebook again
the-praxs Sep 22, 2024
33e27fe
Merge branch 'main' into ai21-ops
the-praxs Sep 27, 2024
d87b0e8
Merge branch 'main' into ai21-ops
the-praxs Oct 4, 2024
243a878
Merge branch 'main' into ai21-ops
the-praxs Oct 15, 2024
1f23886
Merge branch 'main' into ai21-ops
the-praxs Oct 24, 2024
71e140e
Merge branch 'main' into ai21-ops
the-praxs Nov 1, 2024
f20598c
rename examples directory
the-praxs Nov 1, 2024
1801171
Merge branch 'main' into ai21-ops
the-praxs Nov 4, 2024
dddbfcb
updated docs page
areibman Nov 4, 2024
b661c6e
Merge branch 'main' into ai21-ops
the-praxs Nov 4, 2024
fb2b531
wrap `chunk.choices[0].delta.content` in `str(...)`
the-praxs Nov 4, 2024
c23fb6f
update doc
areibman Nov 5, 2024
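Several of the commits above ("fix for `delta.content`", "wrap `chunk.choices[0].delta.content` in `str(...)`") deal with folding streamed chunks into one completion. A dependency-free sketch of that accumulation pattern, using an illustrative chunk shape that mirrors the `choices[0].delta` layout handled in `ai21.py` below:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Delta:
    role: Optional[str] = None
    content: Optional[str] = None

@dataclass
class Choice:
    delta: Delta
    finish_reason: Optional[str] = None

@dataclass
class Chunk:
    choices: List[Choice]

def accumulate(chunks) -> dict:
    # Fold every streamed delta into a single {role, content} completion,
    # mirroring how handle_stream_chunk builds llm_event.completion
    role, content = None, ""
    for chunk in chunks:
        choice = chunk.choices[0]  # assume only choices[0] is relevant
        if choice.delta.role:
            role = choice.delta.role
        if choice.delta.content:
            content += str(choice.delta.content)
        if choice.finish_reason:
            break
    return {"role": role, "content": content}

stream = [
    Chunk([Choice(Delta(role="assistant"))]),
    Chunk([Choice(Delta(content="Hello"))]),
    Chunk([Choice(Delta(content=" world"))]),
    Chunk([Choice(Delta(), finish_reason="stop")]),
]
print(accumulate(stream))  # {'role': 'assistant', 'content': 'Hello world'}

The provider does the same fold inside `handle_stream_chunk`, recording the accumulated role and content once `finish_reason` arrives.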
24 changes: 24 additions & 0 deletions agentops/llms/__init__.py
@@ -13,6 +13,7 @@
from .ollama import OllamaProvider
from .openai import OpenAiProvider
from .anthropic import AnthropicProvider
from .ai21 import AI21Provider

original_func = {}
original_create = None
@@ -39,6 +40,12 @@ class LlmTracker:
        "anthropic": {
            "0.32.0": ("completions.create",),
        },
        "ai21": {
            "2.0.0": (
                "chat.completions.create",
                "client.answer.create",
            ),
        },
    }

    def __init__(self, client):
@@ -135,10 +142,27 @@ def override_api(self):
                    f"Only Anthropic>=0.32.0 supported. v{module_version} found."
                )

        if api == "ai21":
            module_version = version(api)

            if module_version is None:
                logger.warning(
                    "Cannot determine AI21 version. Only AI21>=2.0.0 supported."
                )

            elif Version(module_version) >= parse("2.0.0"):
                provider = AI21Provider(self.client)
                provider.override()
            else:
                logger.warning(
                    f"Only AI21>=2.0.0 supported. v{module_version} found."
                )

    def stop_instrumenting(self):
        OpenAiProvider(self.client).undo_override()
        GroqProvider(self.client).undo_override()
        CohereProvider(self.client).undo_override()
        LiteLLMProvider(self.client).undo_override()
        OllamaProvider(self.client).undo_override()
        AnthropicProvider(self.client).undo_override()
        AI21Provider(self.client).undo_override()
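For context, the version gate above follows the same recipe as the other providers. A standalone sketch of the check, assuming `importlib.metadata` and `packaging` are the sources of `version` and `Version`/`parse` (their imports are not shown in this hunk):

from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version, parse

def meets_minimum(package: str, minimum: str) -> bool:
    # True when the installed package satisfies the minimum version
    try:
        installed = version(package)
    except PackageNotFoundError:
        return False
    return Version(installed) >= parse(minimum)

print(meets_minimum("ai21", "2.0.0"))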
251 changes: 251 additions & 0 deletions agentops/llms/ai21.py
@@ -0,0 +1,251 @@
import inspect
import pprint
from typing import Optional

from agentops.llms.instrumented_provider import InstrumentedProvider
from agentops.time_travel import fetch_completion_override_from_time_travel_cache

from ..event import ErrorEvent, LLMEvent, ActionEvent, ToolEvent
from ..session import Session
from ..log_config import logger
from ..helpers import check_call_stack_for_agent_id, get_ISO_time
from ..singleton import singleton


@singleton
class AI21Provider(InstrumentedProvider):

    original_create = None
    original_create_async = None
    original_answer = None
    original_answer_async = None

    def __init__(self, client):
        super().__init__(client)
        self._provider_name = "AI21"

    def handle_response(
        self, response, kwargs, init_timestamp, session: Optional[Session] = None
    ):
        """Handle responses for AI21"""
        from ai21.stream.stream import Stream
        from ai21.stream.async_stream import AsyncStream
        from ai21.models.chat.chat_completion_chunk import ChatCompletionChunk
        from ai21.models.chat.chat_completion_response import ChatCompletionResponse
        from ai21.models.responses.answer_response import AnswerResponse

        llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
        action_event = ActionEvent(init_timestamp=init_timestamp, params=kwargs)

        if session is not None:
            llm_event.session_id = session.session_id

        def handle_stream_chunk(chunk: ChatCompletionChunk):
            # We take the first ChatCompletionChunk and accumulate the deltas
            # from all subsequent chunks to build one full chat completion
            if llm_event.returns is None:
                llm_event.returns = chunk
                # Manually setting content to empty string to avoid error
                llm_event.returns.choices[0].delta.content = ""

            try:
                accumulated_delta = llm_event.returns.choices[0].delta
                llm_event.agent_id = check_call_stack_for_agent_id()
                llm_event.model = kwargs["model"]
                llm_event.prompt = [
                    message.model_dump() for message in kwargs["messages"]
                ]

                # NOTE: We assume for completion only choices[0] is relevant
                choice = chunk.choices[0]

                if choice.delta.content:
                    accumulated_delta.content += choice.delta.content

                if choice.delta.role:
                    accumulated_delta.role = choice.delta.role

                if getattr(choice.delta, "tool_calls", None):
                    # Carry the latest tool-call deltas forward on the
                    # accumulated delta so they survive until finish_reason
                    accumulated_delta.tool_calls = choice.delta.tool_calls

                if choice.finish_reason:
                    # Streaming is done. Record LLMEvent
                    llm_event.returns.choices[0].finish_reason = choice.finish_reason
                    llm_event.completion = {
                        "role": accumulated_delta.role,
                        "content": accumulated_delta.content,
                    }
                    llm_event.prompt_tokens = chunk.usage.prompt_tokens
                    llm_event.completion_tokens = chunk.usage.completion_tokens
                    llm_event.end_timestamp = get_ISO_time()
                    self._safe_record(session, llm_event)

            except Exception as e:
                self._safe_record(
                    session, ErrorEvent(trigger_event=llm_event, exception=e)
                )

                kwargs_str = pprint.pformat(kwargs)
                chunk_str = pprint.pformat(chunk)
                logger.warning(
                    f"Unable to parse a chunk for LLM call. Skipping upload to AgentOps\n"
                    f"chunk:\n {chunk_str}\n"
                    f"kwargs:\n {kwargs_str}\n"
                )

        # if the response is a generator, decorate the generator
        # For synchronous Stream
        if isinstance(response, Stream):

            def generator():
                for chunk in response:
                    handle_stream_chunk(chunk)
                    yield chunk

            return generator()

        # For asynchronous AsyncStream
        if isinstance(response, AsyncStream):

            async def async_generator():
                async for chunk in response:
                    handle_stream_chunk(chunk)
                    yield chunk

            return async_generator()

        # Handle object responses
        try:
            if isinstance(response, ChatCompletionResponse):
                llm_event.returns = response
                llm_event.agent_id = check_call_stack_for_agent_id()
                llm_event.model = kwargs["model"]
                llm_event.prompt = [
                    message.model_dump() for message in kwargs["messages"]
                ]
                llm_event.prompt_tokens = response.usage.prompt_tokens
                llm_event.completion = response.choices[0].message.model_dump()
                llm_event.completion_tokens = response.usage.completion_tokens
                llm_event.end_timestamp = get_ISO_time()
                self._safe_record(session, llm_event)

            elif isinstance(response, AnswerResponse):
                action_event.returns = response
                action_event.agent_id = check_call_stack_for_agent_id()
                action_event.action_type = "Contextual Answers"
                action_event.logs = [
                    {"context": kwargs["context"], "question": kwargs["question"]},
                    response.model_dump() or None,
                ]
                action_event.end_timestamp = get_ISO_time()
                self._safe_record(session, action_event)

        except Exception as e:
            self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e))

            kwargs_str = pprint.pformat(kwargs)
            response_str = pprint.pformat(response)
            logger.warning(
                f"Unable to parse response for LLM call. Skipping upload to AgentOps\n"
                f"response:\n {response_str}\n"
                f"kwargs:\n {kwargs_str}\n"
            )

        return response

    def override(self):
        self._override_completion()
        self._override_completion_async()
        self._override_answer()
        self._override_answer_async()

    def _override_completion(self):
        from ai21.clients.studio.resources.chat import ChatCompletions

        # Store the original on the instance so undo_override can restore it
        self.original_create = ChatCompletions.create

        def patched_function(*args, **kwargs):
            # Call the original function with its original arguments
            init_timestamp = get_ISO_time()
            session = kwargs.pop("session", None)
            result = self.original_create(*args, **kwargs)
            return self.handle_response(result, kwargs, init_timestamp, session=session)

        # Override the original method with the patched one
        ChatCompletions.create = patched_function

    def _override_completion_async(self):
        from ai21.clients.studio.resources.chat import AsyncChatCompletions

        self.original_create_async = AsyncChatCompletions.create

        async def patched_function(*args, **kwargs):
            # Call the original function with its original arguments
            init_timestamp = get_ISO_time()
            session = kwargs.pop("session", None)
            result = await self.original_create_async(*args, **kwargs)
            return self.handle_response(result, kwargs, init_timestamp, session=session)

        # Override the original method with the patched one
        AsyncChatCompletions.create = patched_function

    def _override_answer(self):
        from ai21.clients.studio.resources.studio_answer import StudioAnswer

        self.original_answer = StudioAnswer.create

        def patched_function(*args, **kwargs):
            # Call the original function with its original arguments
            init_timestamp = get_ISO_time()
            session = kwargs.pop("session", None)
            result = self.original_answer(*args, **kwargs)
            return self.handle_response(result, kwargs, init_timestamp, session=session)

        StudioAnswer.create = patched_function

    def _override_answer_async(self):
        from ai21.clients.studio.resources.studio_answer import AsyncStudioAnswer

        self.original_answer_async = AsyncStudioAnswer.create

        async def patched_function(*args, **kwargs):
            # Call the original function with its original arguments
            init_timestamp = get_ISO_time()
            session = kwargs.pop("session", None)
            result = await self.original_answer_async(*args, **kwargs)
            return self.handle_response(result, kwargs, init_timestamp, session=session)

        AsyncStudioAnswer.create = patched_function

    def undo_override(self):
        if (
            self.original_create is not None
            and self.original_create_async is not None
            and self.original_answer is not None
            and self.original_answer_async is not None
        ):
            from ai21.clients.studio.resources.chat import (
                ChatCompletions,
                AsyncChatCompletions,
            )
            from ai21.clients.studio.resources.studio_answer import (
                StudioAnswer,
                AsyncStudioAnswer,
            )

            ChatCompletions.create = self.original_create
            AsyncChatCompletions.create = self.original_create_async
            StudioAnswer.create = self.original_answer
            AsyncStudioAnswer.create = self.original_answer_async
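The provider instruments the SDK by swapping class methods and keeping the originals on the instance so `undo_override` can restore them; because the provider is a `@singleton`, `stop_instrumenting` gets the same instance back with the saved originals intact. A dependency-free sketch of that patch/restore pattern (the `Resource` class and timing printout are illustrative):

import time

class Resource:
    def create(self, prompt: str) -> str:
        return f"echo: {prompt}"

class Patcher:
    original_create = None

    def override(self):
        # Keep the original unbound function, then swap in a wrapper
        self.original_create = Resource.create

        def patched(*args, **kwargs):
            start = time.monotonic()
            result = self.original_create(*args, **kwargs)
            print(f"create took {time.monotonic() - start:.6f}s")
            return result

        Resource.create = patched

    def undo_override(self):
        if self.original_create is not None:
            Resource.create = self.original_create

patcher = Patcher()
patcher.override()
print(Resource().create("hi"))  # instrumented call
patcher.undo_override()
print(Resource().create("hi"))  # original behavior restored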