Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

record tool #333

Merged
merged 2 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion agentops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from .client import Client
from .event import Event, ActionEvent, LLMEvent, ToolEvent, ErrorEvent
from .decorators import record_function, track_agent
from .decorators import record_function, track_agent, record_tool
from .helpers import check_agentops_update
from .log_config import logger
from .session import Session
Expand Down
140 changes: 138 additions & 2 deletions agentops/decorators.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import inspect
import functools
import inspect
from typing import Optional, Union
from uuid import uuid4

from .event import ActionEvent, ErrorEvent
from .event import ActionEvent, ErrorEvent, ToolEvent
from .helpers import check_call_stack_for_agent_id, get_ISO_time
from .session import Session
from .client import Client
Expand Down Expand Up @@ -146,6 +146,142 @@
return decorator


def record_tool(tool_name: str):
"""
Decorator to record a tool use event before and after a function call.
Usage:
- Tools: Records function parameters and return statements of the
function being decorated. Additionally, timing information about
the action is recorded
Args:
tool_name (str): The name of the event to record.
"""

def decorator(func):
if inspect.iscoroutinefunction(func):

@functools.wraps(func)
async def async_wrapper(*args, session: Optional[Session] = None, **kwargs):
init_time = get_ISO_time()
if "session" in kwargs.keys():
del kwargs["session"]

Check warning on line 167 in agentops/decorators.py

View check run for this annotation

Codecov / codecov/patch

agentops/decorators.py#L167

Added line #L167 was not covered by tests
if session is None:
if Client().is_multi_session:
raise ValueError(

Check warning on line 170 in agentops/decorators.py

View check run for this annotation

Codecov / codecov/patch

agentops/decorators.py#L170

Added line #L170 was not covered by tests
"If multiple sessions exists, `session` is a required parameter in the function decorated by @record_tool"
)
func_args = inspect.signature(func).parameters
arg_names = list(func_args.keys())
# Get default values
arg_values = {
name: func_args[name].default
for name in arg_names
if func_args[name].default is not inspect._empty
}
# Update with positional arguments
arg_values.update(dict(zip(arg_names, args)))
arg_values.update(kwargs)

event = ToolEvent(
params=arg_values,
init_timestamp=init_time,
agent_id=check_call_stack_for_agent_id(),
name=tool_name,
)

try:
returns = await func(*args, **kwargs)

# If the function returns multiple values, record them all in the same event
if isinstance(returns, tuple):
returns = list(returns)

Check warning on line 197 in agentops/decorators.py

View check run for this annotation

Codecov / codecov/patch

agentops/decorators.py#L197

Added line #L197 was not covered by tests

event.returns = returns

# NOTE: Will likely remove in future since this is tightly coupled. Adding it to see how useful we find it for now
# TODO: check if screenshot is the url string we expect it to be? And not e.g. "True"
if hasattr(returns, "screenshot"):
event.screenshot = returns.screenshot # type: ignore

Check warning on line 204 in agentops/decorators.py

View check run for this annotation

Codecov / codecov/patch

agentops/decorators.py#L204

Added line #L204 was not covered by tests

event.end_timestamp = get_ISO_time()

if session:
session.record(event)
else:
Client().record(event)

except Exception as e:
Client().record(ErrorEvent(trigger_event=event, exception=e))

Check warning on line 214 in agentops/decorators.py

View check run for this annotation

Codecov / codecov/patch

agentops/decorators.py#L213-L214

Added lines #L213 - L214 were not covered by tests

# Re-raise the exception
raise

Check warning on line 217 in agentops/decorators.py

View check run for this annotation

Codecov / codecov/patch

agentops/decorators.py#L217

Added line #L217 was not covered by tests

return returns

return async_wrapper
else:

@functools.wraps(func)
def sync_wrapper(*args, session: Optional[Session] = None, **kwargs):
init_time = get_ISO_time()
if "session" in kwargs.keys():
del kwargs["session"]

Check warning on line 228 in agentops/decorators.py

View check run for this annotation

Codecov / codecov/patch

agentops/decorators.py#L228

Added line #L228 was not covered by tests
if session is None:
if Client().is_multi_session:
raise ValueError(
"If multiple sessions exists, `session` is a required parameter in the function decorated by @record_tool"
)
func_args = inspect.signature(func).parameters
arg_names = list(func_args.keys())
# Get default values
arg_values = {
name: func_args[name].default
for name in arg_names
if func_args[name].default is not inspect._empty
}
# Update with positional arguments
arg_values.update(dict(zip(arg_names, args)))
arg_values.update(kwargs)

event = ToolEvent(
params=arg_values,
init_timestamp=init_time,
agent_id=check_call_stack_for_agent_id(),
name=tool_name,
)

try:
returns = func(*args, **kwargs)

# If the function returns multiple values, record them all in the same event
if isinstance(returns, tuple):
returns = list(returns)

Check warning on line 258 in agentops/decorators.py

View check run for this annotation

Codecov / codecov/patch

agentops/decorators.py#L258

Added line #L258 was not covered by tests

event.returns = returns

if hasattr(returns, "screenshot"):
event.screenshot = returns.screenshot # type: ignore

Check warning on line 263 in agentops/decorators.py

View check run for this annotation

Codecov / codecov/patch

agentops/decorators.py#L263

Added line #L263 was not covered by tests

event.end_timestamp = get_ISO_time()

if session:
session.record(event)
else:
Client().record(event)

except Exception as e:
Client().record(ErrorEvent(trigger_event=event, exception=e))

Check warning on line 273 in agentops/decorators.py

View check run for this annotation

Codecov / codecov/patch

agentops/decorators.py#L272-L273

Added lines #L272 - L273 were not covered by tests

# Re-raise the exception
raise

Check warning on line 276 in agentops/decorators.py

View check run for this annotation

Codecov / codecov/patch

agentops/decorators.py#L276

Added line #L276 was not covered by tests

return returns

return sync_wrapper

return decorator


def track_agent(name: Union[str, None] = None):
def decorator(obj):
if name:
Expand Down
228 changes: 228 additions & 0 deletions tests/test_record_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
import pytest
import requests_mock
import time
import agentops
from agentops.decorators import record_tool
from datetime import datetime

from agentops.helpers import clear_singletons
import contextlib

jwts = ["some_jwt", "some_jwt2", "some_jwt3"]


@pytest.fixture(autouse=True)
def setup_teardown():
clear_singletons()
yield
agentops.end_all_sessions() # teardown part


@contextlib.contextmanager
@pytest.fixture(autouse=True)
def mock_req():
with requests_mock.Mocker() as m:
url = "https://api.agentops.ai"
m.post(url + "/v2/create_events", text="ok")

# Use iter to create an iterator that can return the jwt values
jwt_tokens = iter(jwts)

# Use an inner function to change the response for each request
def create_session_response(request, context):
context.status_code = 200
return {"status": "success", "jwt": next(jwt_tokens)}

m.post(url + "/v2/create_session", json=create_session_response)
m.post(url + "/v2/update_session", json={"status": "success", "token_cost": 5})
m.post(url + "/v2/developer_errors", text="ok")

yield m


class TestRecordTool:
def setup_method(self):
self.url = "https://api.agentops.ai"
self.api_key = "11111111-1111-4111-8111-111111111111"
self.tool_name = "test_tool_name"
agentops.init(self.api_key, max_wait_time=5, auto_start_session=False)

def test_record_tool_decorator(self, mock_req):
agentops.start_session()

@record_tool(tool_name=self.tool_name)
def add_two(x, y):
return x + y

# Act
add_two(3, 4)
time.sleep(0.1)

# Assert
assert len(mock_req.request_history) == 2
assert mock_req.last_request.headers["X-Agentops-Api-Key"] == self.api_key
request_json = mock_req.last_request.json()
assert request_json["events"][0]["name"] == self.tool_name
assert request_json["events"][0]["params"] == {"x": 3, "y": 4}
assert request_json["events"][0]["returns"] == 7

agentops.end_session(end_state="Success")

def test_record_tool_decorator_multiple(self, mock_req):
agentops.start_session()

# Arrange
@record_tool(tool_name=self.tool_name)
def add_three(x, y, z=3):
return x + y + z

# Act
add_three(1, 2)
time.sleep(0.1)
add_three(1, 2)
time.sleep(0.1)

# Assert
assert len(mock_req.request_history) == 3
assert mock_req.last_request.headers["X-Agentops-Api-Key"] == self.api_key
request_json = mock_req.last_request.json()
assert request_json["events"][0]["name"] == self.tool_name
assert request_json["events"][0]["params"] == {"x": 1, "y": 2, "z": 3}
assert request_json["events"][0]["returns"] == 6

agentops.end_session(end_state="Success")

@pytest.mark.asyncio
async def test_async_tool_call(self, mock_req):
agentops.start_session()

@record_tool(self.tool_name)
async def async_add(x, y):
time.sleep(0.1)
return x + y

# Act
result = await async_add(3, 4)
time.sleep(0.1)

# Assert
assert result == 7
# Assert
assert len(mock_req.request_history) == 2
assert mock_req.last_request.headers["X-Agentops-Api-Key"] == self.api_key
request_json = mock_req.last_request.json()
assert request_json["events"][0]["name"] == self.tool_name
assert request_json["events"][0]["params"] == {"x": 3, "y": 4}
assert request_json["events"][0]["returns"] == 7

init = datetime.fromisoformat(request_json["events"][0]["init_timestamp"])
end = datetime.fromisoformat(request_json["events"][0]["end_timestamp"])

assert (end - init).total_seconds() >= 0.1

agentops.end_session(end_state="Success")

def test_multiple_sessions_sync(self, mock_req):
session_1 = agentops.start_session()
session_2 = agentops.start_session()
assert session_1 is not None
assert session_2 is not None

# Arrange
@record_tool(tool_name=self.tool_name)
def add_three(x, y, z=3):
return x + y + z

# Act
add_three(1, 2, session=session_1)
time.sleep(0.1)
add_three(1, 2, session=session_2)
time.sleep(0.1)

# Assert
assert len(mock_req.request_history) == 4

request_json = mock_req.last_request.json()
assert mock_req.last_request.headers["X-Agentops-Api-Key"] == self.api_key
assert mock_req.last_request.headers["Authorization"] == f"Bearer some_jwt2"
assert request_json["events"][0]["name"] == self.tool_name
assert request_json["events"][0]["params"] == {"x": 1, "y": 2, "z": 3}
assert request_json["events"][0]["returns"] == 6

second_last_request_json = mock_req.request_history[-2].json()
assert (
mock_req.request_history[-2].headers["X-Agentops-Api-Key"] == self.api_key
)
assert (
mock_req.request_history[-2].headers["Authorization"] == f"Bearer some_jwt"
)
assert second_last_request_json["events"][0]["name"] == self.tool_name
assert second_last_request_json["events"][0]["params"] == {
"x": 1,
"y": 2,
"z": 3,
}
assert second_last_request_json["events"][0]["returns"] == 6

session_1.end_session(end_state="Success")
session_2.end_session(end_state="Success")

@pytest.mark.asyncio
async def test_multiple_sessions_async(self, mock_req):
session_1 = agentops.start_session()
session_2 = agentops.start_session()
assert session_1 is not None
assert session_2 is not None

# Arrange
@record_tool(tool_name=self.tool_name)
async def async_add(x, y):
time.sleep(0.1)
return x + y

# Act
await async_add(1, 2, session=session_1)
time.sleep(0.1)
await async_add(1, 2, session=session_2)
time.sleep(0.1)

# Assert
assert len(mock_req.request_history) == 4

request_json = mock_req.last_request.json()
assert mock_req.last_request.headers["X-Agentops-Api-Key"] == self.api_key
assert mock_req.last_request.headers["Authorization"] == f"Bearer some_jwt2"
assert request_json["events"][0]["name"] == self.tool_name
assert request_json["events"][0]["params"] == {"x": 1, "y": 2}
assert request_json["events"][0]["returns"] == 3

second_last_request_json = mock_req.request_history[-2].json()
assert (
mock_req.request_history[-2].headers["X-Agentops-Api-Key"] == self.api_key
)
assert (
mock_req.request_history[-2].headers["Authorization"] == f"Bearer some_jwt"
)
assert second_last_request_json["events"][0]["name"] == self.tool_name
assert second_last_request_json["events"][0]["params"] == {
"x": 1,
"y": 2,
}
assert second_last_request_json["events"][0]["returns"] == 3

session_1.end_session(end_state="Success")
session_2.end_session(end_state="Success")

def test_require_session_if_multiple(self):
session_1 = agentops.start_session()
session_2 = agentops.start_session()

# Arrange
@record_tool(tool_name=self.tool_name)
def add_two(x, y):
time.sleep(0.1)
return x + y

with pytest.raises(ValueError):
# Act
add_two(1, 2)
Loading