From 132345e7f2c8c0fda3d3c4e299378e45a0607a8a Mon Sep 17 00:00:00 2001 From: Braelyn Boynton Date: Fri, 14 Jun 2024 14:34:42 -0700 Subject: [PATCH] worker and tracker changes --- agentops/client.py | 7 ++- agentops/llm_tracker.py | 26 +++++--- agentops/worker.py | 65 ++++++++++++-------- tests/core_manual_tests/multi_session_llm.py | 24 +++++--- 4 files changed, 78 insertions(+), 44 deletions(-) diff --git a/agentops/client.py b/agentops/client.py index bb2ec271..fa5ef64e 100644 --- a/agentops/client.py +++ b/agentops/client.py @@ -241,10 +241,10 @@ def record( event.trigger_event_id = event.trigger_event.id event.trigger_event_type = event.trigger_event.event_type - self._worker.add_event(event.trigger_event.__dict__, session_id) + self._worker.add_event(event.trigger_event.__dict__, session.session_id) event.trigger_event = None # removes trigger_event from serialization - self._worker.add_event(event.__dict__, session_id) + self._worker.add_event(event.__dict__, session.session_id) def _record_event_sync(self, func, event_name, *args, **kwargs): init_time = get_ISO_time() @@ -376,7 +376,8 @@ def start_session( host_env=get_host_env(self._env_data_opt_out), ) - self._worker = Worker(config or self.config) + if not self._worker: + self._worker = Worker(config or self.config) start_session_result = False if inherited_session_id is not None: diff --git a/agentops/llm_tracker.py b/agentops/llm_tracker.py index fdb161fa..16c02633 100644 --- a/agentops/llm_tracker.py +++ b/agentops/llm_tracker.py @@ -78,7 +78,8 @@ def handle_stream_chunk(chunk): self.client.record(self.llm_event, session_id=session_id) except Exception as e: self.client.record( - ErrorEvent(trigger_event=self.llm_event, exception=e) + ErrorEvent(trigger_event=self.llm_event, exception=e), + session_id=session_id, ) kwargs_str = pprint.pformat(kwargs) chunk = pprint.pformat(chunk) @@ -125,7 +126,10 @@ def generator(): self.client.record(self.llm_event, session_id=session_id) except Exception as e: - self.client.record(ErrorEvent(trigger_event=self.llm_event, exception=e)) + self.client.record( + ErrorEvent(trigger_event=self.llm_event, exception=e), + session_id=session_id, + ) kwargs_str = pprint.pformat(kwargs) response = pprint.pformat(response) logger.warning( @@ -191,7 +195,8 @@ def handle_stream_chunk(chunk: ChatCompletionChunk): self.client.record(self.llm_event, session_id=session_id) except Exception as e: self.client.record( - ErrorEvent(trigger_event=self.llm_event, exception=e) + ErrorEvent(trigger_event=self.llm_event, exception=e), + session_id=session_id, ) kwargs_str = pprint.pformat(kwargs) chunk = pprint.pformat(chunk) @@ -243,7 +248,10 @@ async def async_generator(): self.client.record(self.llm_event, session_id=session_id) except Exception as e: - self.client.record(ErrorEvent(trigger_event=self.llm_event, exception=e)) + self.client.record( + ErrorEvent(trigger_event=self.llm_event, exception=e), + session_id=session_id, + ) kwargs_str = pprint.pformat(kwargs) response = pprint.pformat(response) logger.warning( @@ -331,7 +339,7 @@ def handle_stream_chunk(chunk): action_event.end_timestamp = get_ISO_time() for key, action_event in self.action_events.items(): - self.client.record(action_event) + self.client.record(action_event, session_id=session_id) elif isinstance(chunk, StreamedChatResponse_TextGeneration): self.llm_event.completion += chunk.text @@ -358,7 +366,8 @@ def handle_stream_chunk(chunk): except Exception as e: self.client.record( - ErrorEvent(trigger_event=self.llm_event, exception=e) + ErrorEvent(trigger_event=self.llm_event, exception=e), + session_id=session_id, ) kwargs_str = pprint.pformat(kwargs) chunk = pprint.pformat(chunk) @@ -418,7 +427,10 @@ def generator(): self.client.record(self.llm_event, session_id=session_id) except Exception as e: - self.client.record(ErrorEvent(trigger_event=self.llm_event, exception=e)) + self.client.record( + ErrorEvent(trigger_event=self.llm_event, exception=e), + session_id=session_id, + ) kwargs_str = pprint.pformat(kwargs) response = pprint.pformat(response) logger.warning( diff --git a/agentops/worker.py b/agentops/worker.py index cdc9aadd..c3a60049 100644 --- a/agentops/worker.py +++ b/agentops/worker.py @@ -1,4 +1,6 @@ import json +from uuid import UUID + from .log_config import logger import threading import time @@ -6,47 +8,59 @@ from .config import ClientConfiguration from .session import Session from .helpers import safe_serialize, filter_unjsonable -from typing import Dict, Optional, List +from typing import Dict, Optional, List, Union +import copy + + +class QueueSession: + events: List[Dict] = [] + jwt: str = None class Worker: def __init__(self, config: ClientConfiguration) -> None: self.config = config - self.queue: Dict[str, List[Dict]] = {} + self.queue: Dict[str, QueueSession] = {} self.lock = threading.Lock() self.stop_flag = threading.Event() self.thread = threading.Thread(target=self.run) self.thread.daemon = True self.thread.start() - self.jwt = None - def add_event(self, event: dict, session_id: str) -> None: + def add_event(self, event: dict, session_id: Union[str, UUID]) -> None: + session_id = str(session_id) with self.lock: if session_id in self.queue.keys(): - self.queue[session_id].append(event) + self.queue[session_id].events.append(event) else: - self.queue[session_id] = [event] + self.queue[session_id].events = [event] - if len(self.queue[session_id]) >= self.config.max_queue_size: + if len(self.queue[session_id].events) >= self.config.max_queue_size: self.flush_queue() def flush_queue(self) -> None: + print("flushing queue") with self.lock: - queue_copy = dict(self.queue) # Copy the current items - self.queue.clear() + queue_copy = copy.deepcopy(self.queue) # Copy the current items + + # clear events from queue + for session_id in self.queue.keys(): + self.queue[session_id].events = [] + if len(queue_copy.keys()) > 0: - for session_id, events in queue_copy.items(): - if len(queue_copy[session_id]) > 0: + for session_id, queue_session in queue_copy.items(): + if len(queue_copy[session_id].events) > 0: payload = { - "session_id": session_id, - "events": events, + "events": queue_session.events, } + print(payload) + serialized_payload = safe_serialize(payload).encode("utf-8") HttpClient.post( f"{self.config.endpoint}/v2/create_events", serialized_payload, - jwt=self.jwt, + jwt=queue_session.jwt, ) logger.debug("\n") @@ -54,7 +68,7 @@ def flush_queue(self) -> None: logger.debug(serialized_payload) logger.debug("\n") - def reauthorize_jwt(self, session: Session) -> bool: + def reauthorize_jwt(self, session: Session) -> Union[str, None]: with self.lock: payload = {"session_id": session.session_id} serialized_payload = json.dumps(filter_unjsonable(payload)).encode("utf-8") @@ -67,16 +81,16 @@ def reauthorize_jwt(self, session: Session) -> bool: logger.debug(res.body) if res.code != 200: - return False + return None - self.jwt = res.body.get("jwt", None) - if self.jwt is None: - return False - - return True + jwt = res.body.get("jwt", None) + self.queue[str(session.session_id)].jwt = jwt + return jwt def start_session(self, session: Session) -> bool: - self._session = session + print(f"adding {str(session.session_id)} to queue") + self.queue[str(session.session_id)] = QueueSession() + print(self.queue) with self.lock: payload = {"session": session.__dict__} serialized_payload = json.dumps(filter_unjsonable(payload)).encode("utf-8") @@ -92,8 +106,9 @@ def start_session(self, session: Session) -> bool: if res.code != 200: return False - self.jwt = res.body.get("jwt", None) - if self.jwt is None: + jwt = res.body.get("jwt", None) + self.queue[str(session.session_id)].jwt = jwt + if jwt is None: return False return True @@ -109,7 +124,7 @@ def end_session(self, session: Session) -> str: res = HttpClient.post( f"{self.config.endpoint}/v2/update_session", json.dumps(filter_unjsonable(payload)).encode("utf-8"), - jwt=self.jwt, + jwt=self.queue[str(session.session_id)].jwt, ) logger.debug(res.body) return res.body.get("token_cost", "unknown") diff --git a/tests/core_manual_tests/multi_session_llm.py b/tests/core_manual_tests/multi_session_llm.py index 6f481af8..1baa932a 100644 --- a/tests/core_manual_tests/multi_session_llm.py +++ b/tests/core_manual_tests/multi_session_llm.py @@ -1,13 +1,17 @@ import agentops from openai import OpenAI from dotenv import load_dotenv +from agentops import ActionEvent load_dotenv() -agentops.init() +agentops.init(auto_start_session=False) openai = OpenAI() session_id_1 = agentops.start_session(tags=["multi-session-test-1"]) -session_id_2 = agentops.start_session(tags=["multi-session-test-2"]) +# session_id_2 = agentops.start_session(tags=["multi-session-test-2"]) + +print("session_id_1: {}".format(session_id_1)) +# print("session_id_2: {}".format(session_id_2)) messages = [{"role": "user", "content": "Hello"}] @@ -18,15 +22,17 @@ session_id=session_id_1, # <-- add the agentops session_id to the create function ) -response = openai.chat.completions.create( - model="gpt-3.5-turbo", - messages=messages, - temperature=0.5, - session_id=session_id_2, # <-- add the agentops session_id to the create function -) +agentops.record(ActionEvent(action_type="test event"), session_id=session_id_1) + +# response = openai.chat.completions.create( +# model="gpt-3.5-turbo", +# messages=messages, +# temperature=0.5, +# session_id=session_id_2, # <-- add the agentops session_id to the create function +# ) agentops.end_session(end_state="Success", session_id=session_id_1) -agentops.end_session(end_state="Success", session_id=session_id_2) +# agentops.end_session(end_state="Success", session_id=session_id_2) ### # Used to verify that two sessions are created, each with one LLM event