From 2fcc65836795023d97e6594f9b57022090d1ecaf Mon Sep 17 00:00:00 2001 From: Braelyn Boynton Date: Mon, 5 Aug 2024 11:58:04 -0700 Subject: [PATCH] record tool --- agentops/__init__.py | 2 +- agentops/decorators.py | 196 ++++++++++++-------------------- tests/test_record_tool.py | 228 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 298 insertions(+), 128 deletions(-) create mode 100644 tests/test_record_tool.py diff --git a/agentops/__init__.py b/agentops/__init__.py index fc00e71a..5798c366 100755 --- a/agentops/__init__.py +++ b/agentops/__init__.py @@ -4,7 +4,7 @@ from .client import Client from .event import Event, ActionEvent, LLMEvent, ToolEvent, ErrorEvent -from .decorators import record_function, track_agent +from .decorators import record_function, track_agent, record_tool from .helpers import check_agentops_update from .log_config import logger from .session import Session diff --git a/agentops/decorators.py b/agentops/decorators.py index ce0d4050..e6b071c8 100644 --- a/agentops/decorators.py +++ b/agentops/decorators.py @@ -1,147 +1,89 @@ import inspect -import functools from typing import Optional, Union from uuid import uuid4 -from .event import ActionEvent, ErrorEvent +from .event import ActionEvent, ErrorEvent, ToolEvent from .helpers import check_call_stack_for_agent_id, get_ISO_time from .session import Session from .client import Client from .log_config import logger -def record_function(event_name: str): - """ - Decorator to record an event before and after a function call. - Usage: - - Actions: Records function parameters and return statements of the - function being decorated. Additionally, timing information about - the action is recorded - Args: - event_name (str): The name of the event to record. - """ +def __record_action(func, event_name: str, EventClass): + def sync_wrapper(*args, session: Optional[Session] = None, **kwargs): + return __process_func(func, args, kwargs, session, event_name, EventClass) + + return sync_wrapper + + +def __record_action_async(func, event_name: str, EventClass): + async def async_wrapper(*args, session: Optional[Session] = None, **kwargs): + return await __process_func(func, args, kwargs, session, event_name, EventClass) + + return async_wrapper + + +def __process_func(func, args, kwargs, session, event_name, EventClass): + init_time = get_ISO_time() + if "session" in kwargs.keys(): + del kwargs["session"] + if session is None and Client().is_multi_session: + raise ValueError( + f"If multiple sessions exists, `session` is a required parameter in the function" + ) + func_args = inspect.signature(func).parameters + arg_names = list(func_args.keys()) + arg_values = { # Get default values + name: func_args[name].default + for name in arg_names + if func_args[name].default is not inspect._empty + } + arg_values.update(dict(zip(arg_names, args))) # Update with positional arguments + arg_values.update(kwargs) + + event = EventClass( + params=arg_values, + init_timestamp=init_time, + agent_id=check_call_stack_for_agent_id(), + action_type=event_name, + tool_name=event_name, + ) + + try: + returns = func(*args, **kwargs) + if isinstance( + returns, tuple + ): # If the function returns multiple values, record them all in the same event + returns = list(returns) + event.returns = returns + if hasattr(returns, "screenshot"): + event.screenshot = returns.screenshot # type: ignore + event.end_timestamp = get_ISO_time() + session.record(event) if session else Client().record(event) + + except Exception as e: + Client().record(ErrorEvent(trigger_event=event, exception=e)) + raise + + return returns + +def record_function(event_name: str): def decorator(func): if inspect.iscoroutinefunction(func): - - @functools.wraps(func) - async def async_wrapper(*args, session: Optional[Session] = None, **kwargs): - init_time = get_ISO_time() - if "session" in kwargs.keys(): - del kwargs["session"] - if session is None: - if Client().is_multi_session: - raise ValueError( - "If multiple sessions exists, `session` is a required parameter in the function decorated by @record_function" - ) - func_args = inspect.signature(func).parameters - arg_names = list(func_args.keys()) - # Get default values - arg_values = { - name: func_args[name].default - for name in arg_names - if func_args[name].default is not inspect._empty - } - # Update with positional arguments - arg_values.update(dict(zip(arg_names, args))) - arg_values.update(kwargs) - - event = ActionEvent( - params=arg_values, - init_timestamp=init_time, - agent_id=check_call_stack_for_agent_id(), - action_type=event_name, - ) - - try: - returns = await func(*args, **kwargs) - - # If the function returns multiple values, record them all in the same event - if isinstance(returns, tuple): - returns = list(returns) - - event.returns = returns - - # NOTE: Will likely remove in future since this is tightly coupled. Adding it to see how useful we find it for now - # TODO: check if screenshot is the url string we expect it to be? And not e.g. "True" - if hasattr(returns, "screenshot"): - event.screenshot = returns.screenshot # type: ignore - - event.end_timestamp = get_ISO_time() - - if session: - session.record(event) - else: - Client().record(event) - - except Exception as e: - Client().record(ErrorEvent(trigger_event=event, exception=e)) - - # Re-raise the exception - raise - - return returns - - return async_wrapper + return __record_action_async(func, event_name, ActionEvent) else: + return __record_action(func, event_name, ActionEvent) - @functools.wraps(func) - def sync_wrapper(*args, session: Optional[Session] = None, **kwargs): - init_time = get_ISO_time() - if "session" in kwargs.keys(): - del kwargs["session"] - if session is None: - if Client().is_multi_session: - raise ValueError( - "If multiple sessions exists, `session` is a required parameter in the function decorated by @record_function" - ) - func_args = inspect.signature(func).parameters - arg_names = list(func_args.keys()) - # Get default values - arg_values = { - name: func_args[name].default - for name in arg_names - if func_args[name].default is not inspect._empty - } - # Update with positional arguments - arg_values.update(dict(zip(arg_names, args))) - arg_values.update(kwargs) - - event = ActionEvent( - params=arg_values, - init_timestamp=init_time, - agent_id=check_call_stack_for_agent_id(), - action_type=event_name, - ) - - try: - returns = func(*args, **kwargs) - - # If the function returns multiple values, record them all in the same event - if isinstance(returns, tuple): - returns = list(returns) - - event.returns = returns - - if hasattr(returns, "screenshot"): - event.screenshot = returns.screenshot # type: ignore - - event.end_timestamp = get_ISO_time() - - if session: - session.record(event) - else: - Client().record(event) - - except Exception as e: - Client().record(ErrorEvent(trigger_event=event, exception=e)) - - # Re-raise the exception - raise + return decorator - return returns - return sync_wrapper +def record_tool(tool_name: str): + def decorator(func): + if inspect.iscoroutinefunction(func): + return __record_action_async(func, tool_name, ToolEvent) + else: + return __record_action(func, tool_name, ToolEvent) return decorator diff --git a/tests/test_record_tool.py b/tests/test_record_tool.py new file mode 100644 index 00000000..a0cf1d7c --- /dev/null +++ b/tests/test_record_tool.py @@ -0,0 +1,228 @@ +import pytest +import requests_mock +import time +import agentops +from agentops.decorators import record_tool +from datetime import datetime + +from agentops.helpers import clear_singletons +import contextlib + +jwts = ["some_jwt", "some_jwt2", "some_jwt3"] + + +@pytest.fixture(autouse=True) +def setup_teardown(): + clear_singletons() + yield + agentops.end_all_sessions() # teardown part + + +@contextlib.contextmanager +@pytest.fixture(autouse=True) +def mock_req(): + with requests_mock.Mocker() as m: + url = "https://api.agentops.ai" + m.post(url + "/v2/create_events", text="ok") + + # Use iter to create an iterator that can return the jwt values + jwt_tokens = iter(jwts) + + # Use an inner function to change the response for each request + def create_session_response(request, context): + context.status_code = 200 + return {"status": "success", "jwt": next(jwt_tokens)} + + m.post(url + "/v2/create_session", json=create_session_response) + m.post(url + "/v2/update_session", json={"status": "success", "token_cost": 5}) + m.post(url + "/v2/developer_errors", text="ok") + + yield m + + +class TestRecordAction: + def setup_method(self): + self.url = "https://api.agentops.ai" + self.api_key = "11111111-1111-4111-8111-111111111111" + self.tool_name = "test_tool_name" + agentops.init(self.api_key, max_wait_time=5, auto_start_session=False) + + def test_record_function_decorator(self, mock_req): + agentops.start_session() + + @record_tool(tool_name=self.tool_name) + def add_two(x, y): + return x + y + + # Act + add_two(3, 4) + time.sleep(0.1) + + # Assert + assert len(mock_req.request_history) == 2 + assert mock_req.last_request.headers["X-Agentops-Api-Key"] == self.api_key + request_json = mock_req.last_request.json() + assert request_json["events"][0]["tool_name"] == self.tool_name + assert request_json["events"][0]["params"] == {"x": 3, "y": 4} + assert request_json["events"][0]["returns"] == 7 + + agentops.end_session(end_state="Success") + + def test_record_function_decorator_multiple(self, mock_req): + agentops.start_session() + + # Arrange + @record_tool(tool_name=self.tool_name) + def add_three(x, y, z=3): + return x + y + z + + # Act + add_three(1, 2) + time.sleep(0.1) + add_three(1, 2) + time.sleep(0.1) + + # Assert + assert len(mock_req.request_history) == 3 + assert mock_req.last_request.headers["X-Agentops-Api-Key"] == self.api_key + request_json = mock_req.last_request.json() + assert request_json["events"][0]["tool_name"] == self.tool_name + assert request_json["events"][0]["params"] == {"x": 1, "y": 2, "z": 3} + assert request_json["events"][0]["returns"] == 6 + + agentops.end_session(end_state="Success") + + @pytest.mark.asyncio + async def test_async_function_call(self, mock_req): + agentops.start_session() + + @record_function(self.tool_name) + async def async_add(x, y): + time.sleep(0.1) + return x + y + + # Act + result = await async_add(3, 4) + time.sleep(0.1) + + # Assert + assert result == 7 + # Assert + assert len(mock_req.request_history) == 2 + assert mock_req.last_request.headers["X-Agentops-Api-Key"] == self.api_key + request_json = mock_req.last_request.json() + assert request_json["events"][0]["tool_name"] == self.tool_name + assert request_json["events"][0]["params"] == {"x": 3, "y": 4} + assert request_json["events"][0]["returns"] == 7 + + init = datetime.fromisoformat(request_json["events"][0]["init_timestamp"]) + end = datetime.fromisoformat(request_json["events"][0]["end_timestamp"]) + + assert (end - init).total_seconds() >= 0.1 + + agentops.end_session(end_state="Success") + + def test_multiple_sessions_sync(self, mock_req): + session_1 = agentops.start_session() + session_2 = agentops.start_session() + assert session_1 is not None + assert session_2 is not None + + # Arrange + @record_tool(tool_name=self.tool_name) + def add_three(x, y, z=3): + return x + y + z + + # Act + add_three(1, 2, session=session_1) + time.sleep(0.1) + add_three(1, 2, session=session_2) + time.sleep(0.1) + + # Assert + assert len(mock_req.request_history) == 4 + + request_json = mock_req.last_request.json() + assert mock_req.last_request.headers["X-Agentops-Api-Key"] == self.api_key + assert mock_req.last_request.headers["Authorization"] == f"Bearer some_jwt2" + assert request_json["events"][0]["tool_name"] == self.tool_name + assert request_json["events"][0]["params"] == {"x": 1, "y": 2, "z": 3} + assert request_json["events"][0]["returns"] == 6 + + second_last_request_json = mock_req.request_history[-2].json() + assert ( + mock_req.request_history[-2].headers["X-Agentops-Api-Key"] == self.api_key + ) + assert ( + mock_req.request_history[-2].headers["Authorization"] == f"Bearer some_jwt" + ) + assert second_last_request_json["events"][0]["tool_name"] == self.tool_name + assert second_last_request_json["events"][0]["params"] == { + "x": 1, + "y": 2, + "z": 3, + } + assert second_last_request_json["events"][0]["returns"] == 6 + + session_1.end_session(end_state="Success") + session_2.end_session(end_state="Success") + + @pytest.mark.asyncio + async def test_multiple_sessions_async(self, mock_req): + session_1 = agentops.start_session() + session_2 = agentops.start_session() + assert session_1 is not None + assert session_2 is not None + + # Arrange + @record_tool(tool_name=self.tool_name) + async def async_add(x, y): + time.sleep(0.1) + return x + y + + # Act + await async_add(1, 2, session=session_1) + time.sleep(0.1) + await async_add(1, 2, session=session_2) + time.sleep(0.1) + + # Assert + assert len(mock_req.request_history) == 4 + + request_json = mock_req.last_request.json() + assert mock_req.last_request.headers["X-Agentops-Api-Key"] == self.api_key + assert mock_req.last_request.headers["Authorization"] == f"Bearer some_jwt2" + assert request_json["events"][0]["tool_name"] == self.tool_name + assert request_json["events"][0]["params"] == {"x": 1, "y": 2} + assert request_json["events"][0]["returns"] == 3 + + second_last_request_json = mock_req.request_history[-2].json() + assert ( + mock_req.request_history[-2].headers["X-Agentops-Api-Key"] == self.api_key + ) + assert ( + mock_req.request_history[-2].headers["Authorization"] == f"Bearer some_jwt" + ) + assert second_last_request_json["events"][0]["tool_name"] == self.tool_name + assert second_last_request_json["events"][0]["params"] == { + "x": 1, + "y": 2, + } + assert second_last_request_json["events"][0]["returns"] == 3 + + session_1.end_session(end_state="Success") + session_2.end_session(end_state="Success") + + def test_require_session_if_multiple(self): + session_1 = agentops.start_session() + session_2 = agentops.start_session() + + # Arrange + @record_tool(tool_name=self.tool_name) + def add_two(x, y): + time.sleep(0.1) + return x + y + + with pytest.raises(ValueError): + # Act + add_two(1, 2)