Skip to content

Commit

Permalink
Anthropic beta (#497)
Browse files Browse the repository at this point in the history
* add anthropic beta support

* added beta support

* bump version

* extra char

* fix failing tests
  • Loading branch information
areibman authored Nov 10, 2024
1 parent 84a94dd commit bfef59e
Show file tree
Hide file tree
Showing 5 changed files with 1,598 additions and 81 deletions.
177 changes: 101 additions & 76 deletions agentops/llms/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def handle_response(
"""Handle responses for Anthropic"""
from anthropic import Stream, AsyncStream
from anthropic.resources import AsyncMessages
import anthropic.resources.beta.messages.messages as beta_messages
from anthropic.types import Message

llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
Expand Down Expand Up @@ -155,6 +156,7 @@ def override(self):

def _override_completion(self):
from anthropic.resources import messages
import anthropic.resources.beta.messages.messages as beta_messages
from anthropic.types import (
Message,
RawContentBlockDeltaEvent,
Expand All @@ -167,54 +169,64 @@ def _override_completion(self):

# Store the original method
self.original_create = messages.Messages.create
self.original_create_beta = beta_messages.Messages.create

def patched_function(*args, **kwargs):
init_timestamp = get_ISO_time()
session = kwargs.get("session", None)
if "session" in kwargs.keys():
del kwargs["session"]
def create_patched_function(is_beta=False):
def patched_function(*args, **kwargs):
init_timestamp = get_ISO_time()
session = kwargs.get("session", None)
if "session" in kwargs.keys():
del kwargs["session"]

completion_override = fetch_completion_override_from_time_travel_cache(
kwargs
)
if completion_override:
result_model = None
pydantic_models = (
Message,
RawContentBlockDeltaEvent,
RawContentBlockStartEvent,
RawContentBlockStopEvent,
RawMessageDeltaEvent,
RawMessageStartEvent,
RawMessageStopEvent,
completion_override = fetch_completion_override_from_time_travel_cache(
kwargs
)
if completion_override:
result_model = None
pydantic_models = (
Message,
RawContentBlockDeltaEvent,
RawContentBlockStartEvent,
RawContentBlockStopEvent,
RawMessageDeltaEvent,
RawMessageStartEvent,
RawMessageStopEvent,
)

for pydantic_model in pydantic_models:
try:
result_model = pydantic_model.model_validate_json(
completion_override
for pydantic_model in pydantic_models:
try:
result_model = pydantic_model.model_validate_json(
completion_override
)
break
except Exception as e:
pass

if result_model is None:
logger.error(
f"Time Travel: Pydantic validation failed for {pydantic_models} \n"
f"Time Travel: Completion override was:\n"
f"{pprint.pformat(completion_override)}"
)
break
except Exception as e:
pass

if result_model is None:
logger.error(
f"Time Travel: Pydantic validation failed for {pydantic_models} \n"
f"Time Travel: Completion override was:\n"
f"{pprint.pformat(completion_override)}"
return None
return self.handle_response(
result_model, kwargs, init_timestamp, session=session
)
return None

# Call the original function with its original arguments
original_func = (
self.original_create_beta if is_beta else self.original_create
)
result = original_func(*args, **kwargs)
return self.handle_response(
result_model, kwargs, init_timestamp, session=session
result, kwargs, init_timestamp, session=session
)

# Call the original function with its original arguments
result = self.original_create(*args, **kwargs)
return self.handle_response(result, kwargs, init_timestamp, session=session)
return patched_function

# Override the original method with the patched one
messages.Messages.create = patched_function
# Override the original methods with the patched ones
messages.Messages.create = create_patched_function(is_beta=False)
beta_messages.Messages.create = create_patched_function(is_beta=True)

def _override_async_completion(self):
from anthropic.resources import messages
Expand All @@ -227,58 +239,71 @@ def _override_async_completion(self):
RawMessageStartEvent,
RawMessageStopEvent,
)
import anthropic.resources.beta.messages.messages as beta_messages

# Store the original method
self.original_create_async = messages.AsyncMessages.create
self.original_create_async_beta = beta_messages.AsyncMessages.create

async def patched_function(*args, **kwargs):
# Call the original function with its original arguments
init_timestamp = get_ISO_time()
session = kwargs.get("session", None)
if "session" in kwargs.keys():
del kwargs["session"]
def create_patched_async_function(is_beta=False):
async def patched_function(*args, **kwargs):
init_timestamp = get_ISO_time()
session = kwargs.get("session", None)
if "session" in kwargs.keys():
del kwargs["session"]

completion_override = fetch_completion_override_from_time_travel_cache(
kwargs
)
if completion_override:
result_model = None
pydantic_models = (
Message,
RawContentBlockDeltaEvent,
RawContentBlockStartEvent,
RawContentBlockStopEvent,
RawMessageDeltaEvent,
RawMessageStartEvent,
RawMessageStopEvent,
completion_override = fetch_completion_override_from_time_travel_cache(
kwargs
)
if completion_override:
result_model = None
pydantic_models = (
Message,
RawContentBlockDeltaEvent,
RawContentBlockStartEvent,
RawContentBlockStopEvent,
RawMessageDeltaEvent,
RawMessageStartEvent,
RawMessageStopEvent,
)

for pydantic_model in pydantic_models:
try:
result_model = pydantic_model.model_validate_json(
completion_override
for pydantic_model in pydantic_models:
try:
result_model = pydantic_model.model_validate_json(
completion_override
)
break
except Exception as e:
pass

if result_model is None:
logger.error(
f"Time Travel: Pydantic validation failed for {pydantic_models} \n"
f"Time Travel: Completion override was:\n"
f"{pprint.pformat(completion_override)}"
)
break
except Exception as e:
pass

if result_model is None:
logger.error(
f"Time Travel: Pydantic validation failed for {pydantic_models} \n"
f"Time Travel: Completion override was:\n"
f"{pprint.pformat(completion_override)}"
return None

return self.handle_response(
result_model, kwargs, init_timestamp, session=session
)
return None

# Call the original function with its original arguments
original_func = (
self.original_create_async_beta
if is_beta
else self.original_create_async
)
result = await original_func(*args, **kwargs)
return self.handle_response(
result_model, kwargs, init_timestamp, session=session
result, kwargs, init_timestamp, session=session
)

result = await self.original_create_async(*args, **kwargs)
return self.handle_response(result, kwargs, init_timestamp, session=session)
return patched_function

# Override the original method with the patched one
messages.AsyncMessages.create = patched_function
# Override the original methods with the patched ones
messages.AsyncMessages.create = create_patched_async_function(is_beta=False)
beta_messages.AsyncMessages.create = create_patched_async_function(is_beta=True)

def undo_override(self):
if self.original_create is not None and self.original_create_async is not None:
Expand Down
Loading

0 comments on commit bfef59e

Please sign in to comment.