-
Notifications
You must be signed in to change notification settings - Fork 265
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Revert "Add integration tests for Anthropic/Cohere/Groq/Litellm/Mistr…
- Loading branch information
Showing
8 changed files
with
306 additions
and
597 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,110 +1,64 @@ | ||
import asyncio | ||
import os | ||
|
||
import agentops | ||
from dotenv import load_dotenv | ||
import os | ||
import ai21 | ||
from ai21.models.chat import ChatMessage | ||
from ai21.clients.studio.resources.chat import ChatCompletions, AsyncChatCompletions | ||
|
||
load_dotenv() | ||
agentops.init(default_tags=["ai21-provider-test"]) | ||
|
||
api_key = os.getenv("AI2I_API_KEY") | ||
ai21_client = ai21.AI21Client(api_key=api_key) | ||
async_ai21_client = ai21.AsyncAI21Client(api_key=api_key) | ||
|
||
messages = [ | ||
ChatMessage(content="You are an expert mathematician.", role="system"), | ||
ChatMessage( | ||
content="Write a summary of 5 lines on the Shockley diode equation.", | ||
role="user", | ||
), | ||
] | ||
|
||
response = ai21_client.chat.completions.create( | ||
model="jamba-1.5-mini", | ||
messages=messages, | ||
) | ||
|
||
|
||
stream_response = ai21_client.chat.completions.create( | ||
model="jamba-1.5-mini", | ||
messages=messages, | ||
stream=True, | ||
) | ||
|
||
response = "" | ||
for chunk in stream_response: | ||
response += chunk.choices[0].delta.content | ||
print(response) | ||
|
||
|
||
async def async_test(): | ||
async_response = await async_ai21_client.chat.completions.create( | ||
model="jamba-1.5-mini", | ||
messages=messages, | ||
) | ||
print(async_response) | ||
|
||
|
||
asyncio.run(async_test()) | ||
|
||
agentops.stop_instrumenting() | ||
|
||
untracked_response = ai21_client.chat.completions.create( | ||
model="jamba-1.5-mini", | ||
messages=messages, | ||
) | ||
|
||
|
||
agentops.end_session(end_state="Success") | ||
|
||
# Check for required API key | ||
if not os.getenv("AI21_API_KEY"): | ||
raise ValueError("AI21_API_KEY environment variable is required") | ||
|
||
def test_ai21_integration(): | ||
"""Integration test demonstrating all four AI21 call patterns: | ||
1. Sync (non-streaming) | ||
2. Sync (streaming) | ||
3. Async (non-streaming) | ||
4. Async (streaming) | ||
Verifies that AgentOps correctly tracks all LLM calls via analytics. | ||
""" | ||
# Initialize AgentOps without auto-starting session | ||
agentops.init(auto_start_session=False) | ||
session = agentops.start_session() | ||
|
||
api_key = os.getenv("AI21_API_KEY") | ||
# Initialize provider | ||
from agentops.llms.providers.ai21 import AI21Provider | ||
provider = AI21Provider(None) # AI21 doesn't need a client instance | ||
provider.override() | ||
|
||
# Pass session to provider | ||
provider.client = session | ||
ai21_client = ai21.AI21Client(api_key=api_key) | ||
async_ai21_client = ai21.AsyncAI21Client(api_key=api_key) | ||
chat_client = ChatCompletions(client=ai21_client) | ||
async_chat_client = AsyncChatCompletions(client=async_ai21_client) | ||
|
||
# Create message objects | ||
base_messages = [ | ||
ChatMessage(role="system", content="You are a helpful AI assistant"), | ||
ChatMessage(role="user", content="Hello from the test suite") | ||
] | ||
sync_messages = base_messages.copy() | ||
sync_stream_messages = base_messages.copy() | ||
async_messages = base_messages.copy() | ||
async_stream_messages = base_messages.copy() | ||
|
||
def sync_no_stream(): | ||
chat_client.create( | ||
model="jamba-instruct", | ||
system="You are a helpful AI assistant", | ||
messages=sync_messages, | ||
maxTokens=10 | ||
) | ||
|
||
def sync_stream(): | ||
stream_response = chat_client.create( | ||
model="jamba-instruct", | ||
system="You are a helpful AI assistant", | ||
messages=sync_stream_messages, | ||
maxTokens=10, | ||
stream=True | ||
) | ||
for chunk in stream_response: | ||
_ = chunk.choices[0].delta.content if hasattr(chunk.choices[0].delta, "content") else "" | ||
|
||
async def async_no_stream(): | ||
await async_chat_client.create( | ||
model="jamba-instruct", | ||
system="You are a helpful AI assistant", | ||
messages=async_messages, | ||
maxTokens=10 | ||
) | ||
|
||
async def async_stream(): | ||
async_stream_response = await async_chat_client.create( | ||
model="jamba-instruct", | ||
system="You are a helpful AI assistant", | ||
messages=async_stream_messages, | ||
maxTokens=10, | ||
stream=True | ||
) | ||
async for chunk in async_stream_response: | ||
_ = chunk.choices[0].delta.content if hasattr(chunk.choices[0].delta, 'content') else '' | ||
|
||
async def run_async_tests(): | ||
await async_no_stream() | ||
await async_stream() | ||
|
||
# Call each function with proper error handling | ||
try: | ||
sync_no_stream() | ||
sync_stream() | ||
asyncio.run(run_async_tests()) | ||
except Exception as e: | ||
print(f"Error during AI21 test: {str(e)}") | ||
raise | ||
finally: | ||
session.end_session("Success") | ||
analytics = session.get_analytics() | ||
print(f"Analytics: {analytics}") | ||
assert analytics["LLM calls"] >= 4, f"Expected at least 4 LLM calls, but got {analytics['LLM calls']}" | ||
|
||
# Analytics verification is handled in the finally block | ||
|
||
if __name__ == "__main__": | ||
test_ai21_integration() | ||
### | ||
# Used to verify that one session is created with one LLM event | ||
### |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,110 +1,78 @@ | ||
import asyncio | ||
import os | ||
|
||
import agentops | ||
from dotenv import load_dotenv | ||
import anthropic | ||
|
||
load_dotenv() | ||
agentops.init(default_tags=["anthropic-provider-test"]) | ||
anthropic_client = anthropic.Anthropic() | ||
async_anthropic_client = anthropic.AsyncAnthropic() | ||
|
||
response = anthropic_client.messages.create( | ||
max_tokens=1024, | ||
model="claude-3-5-sonnet-20240620", | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "say hi", | ||
} | ||
], | ||
) | ||
|
||
|
||
def test_anthropic_integration(): | ||
"""Integration test demonstrating all four Anthropic call patterns: | ||
1. Sync (non-streaming) | ||
2. Sync (streaming) | ||
3. Async (non-streaming) | ||
4. Async (streaming) | ||
stream_response = anthropic_client.messages.create( | ||
max_tokens=1024, | ||
model="claude-3-5-sonnet-20240620", | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "say hi 2", | ||
} | ||
], | ||
stream=True, | ||
) | ||
|
||
Verifies that AgentOps correctly tracks all LLM calls via analytics. | ||
""" | ||
# Initialize AgentOps without auto-starting session | ||
agentops.init(auto_start_session=False) | ||
session = agentops.start_session() | ||
response = "" | ||
for event in stream_response: | ||
if event.type == "content_block_delta": | ||
response += event.delta.text | ||
elif event.type == "message_stop": | ||
print(response) | ||
|
||
# Initialize clients and provider | ||
anthropic_client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) | ||
async_anthropic_client = anthropic.AsyncAnthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) | ||
from agentops.llms.providers.anthropic import AnthropicProvider | ||
provider = AnthropicProvider(anthropic_client) | ||
provider.override() | ||
|
||
# Pass session to provider | ||
provider.client = session | ||
|
||
def sync_no_stream(): | ||
anthropic_client.messages.create( | ||
max_tokens=1024, | ||
model="claude-3-5-sonnet-20240620", | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "Hello from sync no stream", | ||
} | ||
], | ||
session=session | ||
) | ||
async def async_test(): | ||
async_response = await async_anthropic_client.messages.create( | ||
max_tokens=1024, | ||
model="claude-3-5-sonnet-20240620", | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "say hi 3", | ||
} | ||
], | ||
) | ||
print(async_response) | ||
|
||
def sync_stream(): | ||
stream_response = anthropic_client.messages.create( | ||
max_tokens=1024, | ||
model="claude-3-5-sonnet-20240620", | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "Hello from sync streaming", | ||
} | ||
], | ||
stream=True, | ||
session=session | ||
) | ||
for _ in stream_response: | ||
pass | ||
|
||
async def async_no_stream(): | ||
await async_anthropic_client.messages.create( | ||
max_tokens=1024, | ||
model="claude-3-5-sonnet-20240620", | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "Hello from async no stream", | ||
} | ||
], | ||
session=session | ||
) | ||
asyncio.run(async_test()) | ||
|
||
async def async_stream(): | ||
async_stream_response = await async_anthropic_client.messages.create( | ||
max_tokens=1024, | ||
model="claude-3-5-sonnet-20240620", | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "Hello from async streaming", | ||
} | ||
], | ||
stream=True, | ||
session=session | ||
) | ||
async for _ in async_stream_response: | ||
pass | ||
agentops.stop_instrumenting() | ||
|
||
async def run_async_tests(): | ||
await async_no_stream() | ||
await async_stream() | ||
untracked_response = anthropic_client.messages.create( | ||
max_tokens=1024, | ||
model="claude-3-5-sonnet-20240620", | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "say hi 4", | ||
} | ||
], | ||
) | ||
|
||
# Call each function with proper error handling | ||
try: | ||
sync_no_stream() | ||
sync_stream() | ||
asyncio.run(run_async_tests()) | ||
except Exception as e: | ||
print(f"Error during Anthropic test: {str(e)}") | ||
raise | ||
|
||
session.end_session("Success") | ||
analytics = session.get_analytics() | ||
print(analytics) | ||
# Verify that all LLM calls were tracked | ||
assert analytics["LLM calls"] >= 4, f"Expected at least 4 LLM calls, but got {analytics['LLM calls']}" | ||
agentops.end_session(end_state="Success") | ||
|
||
if __name__ == "__main__": | ||
test_anthropic_integration() | ||
### | ||
# Used to verify that one session is created with one LLM event | ||
### |
Oops, something went wrong.