From 8e8d1732d45cbb09afef4d82bfdf150feb0f9d72 Mon Sep 17 00:00:00 2001 From: Teo Date: Tue, 26 Nov 2024 23:32:19 -0600 Subject: [PATCH] sir fix-a-lot Signed-off-by: Teo --- agentops/session/api.py | 8 +++- agentops/session/exporter.py | 76 ++++++++++++++++++------------------ agentops/session/session.py | 17 ++++++-- 3 files changed, 59 insertions(+), 42 deletions(-) diff --git a/agentops/session/api.py b/agentops/session/api.py index 124dc856..6a043b55 100644 --- a/agentops/session/api.py +++ b/agentops/session/api.py @@ -70,7 +70,13 @@ def batch(self, events: List[Union[Event, dict]]) -> Response: if res.status == HttpStatus.SUCCESS: for event in events: # Handle both Event objects and dictionaries - event_type = event.event_type if isinstance(event, Event) else event["event_type"] + if isinstance(event, Event): + event_type = event.event_type + else: + # For dict events, get the value that matches an EventType value + event_type = event["event_type"] + + # Use the enum value for counting if event_type in self.session.state.event_counts: self.session.state.event_counts[event_type] += 1 diff --git a/agentops/session/exporter.py b/agentops/session/exporter.py index 4e8c9cca..b85d6e3e 100644 --- a/agentops/session/exporter.py +++ b/agentops/session/exporter.py @@ -4,6 +4,7 @@ import sys import threading from abc import ABC +from dataclasses import asdict from datetime import datetime, timezone from decimal import ROUND_HALF_UP, Decimal from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Protocol, Sequence, Union, cast @@ -22,7 +23,7 @@ from agentops.config import Configuration from agentops.enums import EndState, EventType -from agentops.event import ErrorEvent, Event +from agentops.event import ActionEvent, ErrorEvent, Event, ToolEvent from agentops.exceptions import ApiServerException from agentops.helpers import filter_unjsonable, get_ISO_time, safe_serialize from agentops.http_client import HttpClient, Response @@ -108,60 +109,59 @@ def to_span_attributes(obj: Any) -> Dict[str, AttributeValue]: @staticmethod def from_span_attributes(attrs: Union[Dict[str, AttributeValue], Mapping[str, AttributeValue]]) -> Dict[str, Any]: """Convert span attributes back to a dictionary of event attributes""" - # Create a mutable copy of the attributes - attrs_dict = dict(attrs) - # Get the serialized event data try: - event_data = json.loads(str(attrs_dict.get("event.data", "{}"))) + event_data = json.loads(str(attrs.get("event.data", "{}"))) except json.JSONDecodeError: event_data = {} # Get timestamps, providing defaults if missing current_time = datetime.now(timezone.utc).isoformat() - init_timestamp = attrs_dict.get("event.timestamp", current_time) - end_timestamp = attrs_dict.get("event.end_timestamp", current_time) + init_timestamp = attrs.get("event.timestamp") or event_data.get("init_timestamp", current_time) + end_timestamp = attrs.get("event.end_timestamp") or event_data.get("end_timestamp", current_time) # Build base event structure - event = { - "id": attrs_dict.get("event.id", str(uuid4())), - "event_type": attrs_dict.get("event.type", event_data.get("event_type", "unknown")), + base_kwargs = { + "id": attrs.get("event.id", str(uuid4())), + "event_type": attrs.get("event.type", event_data.get("event_type", "unknown")), "init_timestamp": init_timestamp, "end_timestamp": end_timestamp, } - # Format event data based on category - if attrs_dict.get("event.category") == "actions": - # For action events, try multiple sources for action_type + if attrs.get("event.category") == "actions": action_type = ( - attrs_dict.get("event.action_type") # Try direct attribute first - or event_data.get("action_type") # Then try event data - or event_data.get("args", [None])[0] # Then try first arg - or "unknown_action" # Finally fall back to unknown + attrs.get("event.action_type") + or event_data.get("action_type") + or event_data.get("args", [None])[0] + or "unknown_action" ) - event.update( - { - "action_type": action_type, - "params": event_data.get("params", {}), - "returns": event_data.get("returns"), - } + return asdict( + ActionEvent( + action_type=str(action_type), + params=event_data.get("params", {}), + returns=event_data.get("returns"), + **base_kwargs, + ) ) - elif event["event_type"] == "tools": - event.update( - { - "name": event_data.get("name", event_data.get("tool_name", "unknown_tool")), - "params": event_data.get("params", {}), - "returns": event_data.get("returns"), - } + + elif base_kwargs["event_type"] == "tools": + return asdict( + ToolEvent( + name=event_data.get("name", event_data.get("tool_name", "unknown_tool")), + params=event_data.get("params", {}), + returns=event_data.get("returns"), + **base_kwargs, + ) ) - else: - # For other event types, include all data except what we already used - remaining_data = { - k: v for k, v in event_data.items() if k not in ["id", "timestamp", "end_timestamp", "type"] - } - event.update(remaining_data) - - return event + + # Default to base Event for other types + event = Event(**base_kwargs) + event_dict = asdict(event) + # Add any remaining data + event_dict.update( + {k: v for k, v in event_data.items() if k not in ["id", "timestamp", "end_timestamp", "type"]} + ) + return event_dict class SessionProtocol(Protocol): diff --git a/agentops/session/session.py b/agentops/session/session.py index 6f7a36b4..be10ff51 100644 --- a/agentops/session/session.py +++ b/agentops/session/session.py @@ -266,10 +266,19 @@ def record(self, event: Union[Event, ErrorEvent], flush_now=False) -> None: if not self.is_running: return - # Ensure event has all required base attributes + # Handle ErrorEvent separately since it doesn't inherit from Event + if isinstance(event, ErrorEvent): + if not hasattr(event, "timestamp"): + event.timestamp = get_ISO_time() + # ErrorEvent doesn't need other timestamp fields + self.state.recent_events.append(event) + self._record_otel_event(event, flush_now) + return + + # For regular Event types if not hasattr(event, "id"): event.id = uuid4() - if not hasattr(event, "init_timestamp"): + if not hasattr(event, "init_timestamp") or event.init_timestamp is None: event.init_timestamp = get_ISO_time() if not hasattr(event, "end_timestamp") or event.end_timestamp is None: event.end_timestamp = get_ISO_time() @@ -293,7 +302,9 @@ def _update_session(self) -> None: with self._locks["update_session"]: if not self.is_running: return - self.api.update_session() + response_body, _ = self.api.update_session() + if response_body and "token_cost" in response_body: + self.state.token_cost = Decimal(str(response_body["token_cost"])) def get_analytics(self) -> Dict[str, Union[int, str]]: """Get session analytics