-
Notifications
You must be signed in to change notification settings - Fork 259
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore: add base test class for provider tests
Co-Authored-By: Alex Reibman <[email protected]>
- Loading branch information
1 parent
1031bf5
commit f53f95e
Showing
1 changed file
with
42 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import pytest | ||
import asyncio | ||
import json | ||
from unittest.mock import AsyncMock, MagicMock | ||
|
||
from agentops.session import Session | ||
from agentops.client import Client | ||
from agentops.event import LLMEvent | ||
|
||
|
||
class BaseProviderTest: | ||
"""Base class for provider tests.""" | ||
|
||
async def async_setup_method(self, method): | ||
"""Set up test method.""" | ||
# Initialize mock client and session | ||
self.mock_req = AsyncMock() | ||
self.session = Session(client=Client(api_key="test-key")) | ||
self.session.client.http_client = self.mock_req | ||
|
||
async def teardown_method(self, method): | ||
"""Clean up after test.""" | ||
if hasattr(self, 'provider'): | ||
self.provider.undo_override() | ||
|
||
async def async_verify_events(self, session, expected_count=1): | ||
"""Verify events were recorded.""" | ||
await asyncio.sleep(0.1) # Allow time for async event processing | ||
create_events_requests = [req for req in self.mock_req.request_history if req.url.endswith("/v2/create_events")] | ||
assert len(create_events_requests) >= 1, "No events were recorded" | ||
request_body = json.loads(create_events_requests[-1].body.decode("utf-8")) | ||
assert "session_id" in request_body, "Session ID not found in request" | ||
|
||
async def async_verify_llm_event(self, mock_req, model=None): | ||
"""Verify LLM event was recorded.""" | ||
await asyncio.sleep(0.1) # Allow time for async event processing | ||
create_events_requests = [req for req in mock_req.request_history if req.url.endswith("/v2/create_events")] | ||
assert len(create_events_requests) >= 1, "No events were recorded" | ||
request_body = json.loads(create_events_requests[-1].body.decode("utf-8")) | ||
assert "event_type" in request_body and request_body["event_type"] == "llms", "LLM event not found" | ||
if model: | ||
assert "model" in request_body and request_body["model"] == model, f"Model {model} not found in event" |