Skip to content

Commit

Permalink
worker and tracker changes
Browse files Browse the repository at this point in the history
  • Loading branch information
bboynton97 committed Jun 14, 2024
1 parent e3b5cac commit 132345e
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 44 deletions.
7 changes: 4 additions & 3 deletions agentops/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,10 @@ def record(

event.trigger_event_id = event.trigger_event.id
event.trigger_event_type = event.trigger_event.event_type
self._worker.add_event(event.trigger_event.__dict__, session_id)
self._worker.add_event(event.trigger_event.__dict__, session.session_id)
event.trigger_event = None # removes trigger_event from serialization

self._worker.add_event(event.__dict__, session_id)
self._worker.add_event(event.__dict__, session.session_id)

def _record_event_sync(self, func, event_name, *args, **kwargs):
init_time = get_ISO_time()
Expand Down Expand Up @@ -376,7 +376,8 @@ def start_session(
host_env=get_host_env(self._env_data_opt_out),
)

self._worker = Worker(config or self.config)
if not self._worker:
self._worker = Worker(config or self.config)

start_session_result = False
if inherited_session_id is not None:
Expand Down
26 changes: 19 additions & 7 deletions agentops/llm_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def handle_stream_chunk(chunk):
self.client.record(self.llm_event, session_id=session_id)
except Exception as e:
self.client.record(
ErrorEvent(trigger_event=self.llm_event, exception=e)
ErrorEvent(trigger_event=self.llm_event, exception=e),
session_id=session_id,
)
kwargs_str = pprint.pformat(kwargs)
chunk = pprint.pformat(chunk)
Expand Down Expand Up @@ -125,7 +126,10 @@ def generator():

self.client.record(self.llm_event, session_id=session_id)
except Exception as e:
self.client.record(ErrorEvent(trigger_event=self.llm_event, exception=e))
self.client.record(
ErrorEvent(trigger_event=self.llm_event, exception=e),
session_id=session_id,
)
kwargs_str = pprint.pformat(kwargs)
response = pprint.pformat(response)
logger.warning(
Expand Down Expand Up @@ -191,7 +195,8 @@ def handle_stream_chunk(chunk: ChatCompletionChunk):
self.client.record(self.llm_event, session_id=session_id)
except Exception as e:
self.client.record(
ErrorEvent(trigger_event=self.llm_event, exception=e)
ErrorEvent(trigger_event=self.llm_event, exception=e),
session_id=session_id,
)
kwargs_str = pprint.pformat(kwargs)
chunk = pprint.pformat(chunk)
Expand Down Expand Up @@ -243,7 +248,10 @@ async def async_generator():

self.client.record(self.llm_event, session_id=session_id)
except Exception as e:
self.client.record(ErrorEvent(trigger_event=self.llm_event, exception=e))
self.client.record(
ErrorEvent(trigger_event=self.llm_event, exception=e),
session_id=session_id,
)
kwargs_str = pprint.pformat(kwargs)
response = pprint.pformat(response)
logger.warning(
Expand Down Expand Up @@ -331,7 +339,7 @@ def handle_stream_chunk(chunk):
action_event.end_timestamp = get_ISO_time()

for key, action_event in self.action_events.items():
self.client.record(action_event)
self.client.record(action_event, session_id=session_id)

elif isinstance(chunk, StreamedChatResponse_TextGeneration):
self.llm_event.completion += chunk.text
Expand All @@ -358,7 +366,8 @@ def handle_stream_chunk(chunk):

except Exception as e:
self.client.record(
ErrorEvent(trigger_event=self.llm_event, exception=e)
ErrorEvent(trigger_event=self.llm_event, exception=e),
session_id=session_id,
)
kwargs_str = pprint.pformat(kwargs)
chunk = pprint.pformat(chunk)
Expand Down Expand Up @@ -418,7 +427,10 @@ def generator():

self.client.record(self.llm_event, session_id=session_id)
except Exception as e:
self.client.record(ErrorEvent(trigger_event=self.llm_event, exception=e))
self.client.record(
ErrorEvent(trigger_event=self.llm_event, exception=e),
session_id=session_id,
)
kwargs_str = pprint.pformat(kwargs)
response = pprint.pformat(response)
logger.warning(
Expand Down
65 changes: 40 additions & 25 deletions agentops/worker.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,74 @@
import json
from uuid import UUID

from .log_config import logger
import threading
import time
from .http_client import HttpClient
from .config import ClientConfiguration
from .session import Session
from .helpers import safe_serialize, filter_unjsonable
from typing import Dict, Optional, List
from typing import Dict, Optional, List, Union
import copy


class QueueSession:
events: List[Dict] = []
jwt: str = None


class Worker:
def __init__(self, config: ClientConfiguration) -> None:
self.config = config
self.queue: Dict[str, List[Dict]] = {}
self.queue: Dict[str, QueueSession] = {}
self.lock = threading.Lock()
self.stop_flag = threading.Event()
self.thread = threading.Thread(target=self.run)
self.thread.daemon = True
self.thread.start()
self.jwt = None

def add_event(self, event: dict, session_id: str) -> None:
def add_event(self, event: dict, session_id: Union[str, UUID]) -> None:
session_id = str(session_id)
with self.lock:
if session_id in self.queue.keys():
self.queue[session_id].append(event)
self.queue[session_id].events.append(event)
else:
self.queue[session_id] = [event]
self.queue[session_id].events = [event]

if len(self.queue[session_id]) >= self.config.max_queue_size:
if len(self.queue[session_id].events) >= self.config.max_queue_size:
self.flush_queue()

def flush_queue(self) -> None:
print("flushing queue")
with self.lock:
queue_copy = dict(self.queue) # Copy the current items
self.queue.clear()
queue_copy = copy.deepcopy(self.queue) # Copy the current items

# clear events from queue
for session_id in self.queue.keys():
self.queue[session_id].events = []

if len(queue_copy.keys()) > 0:
for session_id, events in queue_copy.items():
if len(queue_copy[session_id]) > 0:
for session_id, queue_session in queue_copy.items():
if len(queue_copy[session_id].events) > 0:
payload = {
"session_id": session_id,
"events": events,
"events": queue_session.events,
}

print(payload)

serialized_payload = safe_serialize(payload).encode("utf-8")
HttpClient.post(
f"{self.config.endpoint}/v2/create_events",
serialized_payload,
jwt=self.jwt,
jwt=queue_session.jwt,
)

logger.debug("\n<AGENTOPS_DEBUG_OUTPUT>")
logger.debug(f"Worker request to {self.config.endpoint}/events")
logger.debug(serialized_payload)
logger.debug("</AGENTOPS_DEBUG_OUTPUT>\n")

def reauthorize_jwt(self, session: Session) -> bool:
def reauthorize_jwt(self, session: Session) -> Union[str, None]:
with self.lock:
payload = {"session_id": session.session_id}
serialized_payload = json.dumps(filter_unjsonable(payload)).encode("utf-8")
Expand All @@ -67,16 +81,16 @@ def reauthorize_jwt(self, session: Session) -> bool:
logger.debug(res.body)

if res.code != 200:
return False
return None

self.jwt = res.body.get("jwt", None)
if self.jwt is None:
return False

return True
jwt = res.body.get("jwt", None)
self.queue[str(session.session_id)].jwt = jwt
return jwt

def start_session(self, session: Session) -> bool:
self._session = session
print(f"adding {str(session.session_id)} to queue")
self.queue[str(session.session_id)] = QueueSession()
print(self.queue)
with self.lock:
payload = {"session": session.__dict__}
serialized_payload = json.dumps(filter_unjsonable(payload)).encode("utf-8")
Expand All @@ -92,8 +106,9 @@ def start_session(self, session: Session) -> bool:
if res.code != 200:
return False

self.jwt = res.body.get("jwt", None)
if self.jwt is None:
jwt = res.body.get("jwt", None)
self.queue[str(session.session_id)].jwt = jwt
if jwt is None:
return False

return True
Expand All @@ -109,7 +124,7 @@ def end_session(self, session: Session) -> str:
res = HttpClient.post(
f"{self.config.endpoint}/v2/update_session",
json.dumps(filter_unjsonable(payload)).encode("utf-8"),
jwt=self.jwt,
jwt=self.queue[str(session.session_id)].jwt,
)
logger.debug(res.body)
return res.body.get("token_cost", "unknown")
Expand Down
24 changes: 15 additions & 9 deletions tests/core_manual_tests/multi_session_llm.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import agentops
from openai import OpenAI
from dotenv import load_dotenv
from agentops import ActionEvent

load_dotenv()
agentops.init()
agentops.init(auto_start_session=False)
openai = OpenAI()

session_id_1 = agentops.start_session(tags=["multi-session-test-1"])
session_id_2 = agentops.start_session(tags=["multi-session-test-2"])
# session_id_2 = agentops.start_session(tags=["multi-session-test-2"])

print("session_id_1: {}".format(session_id_1))
# print("session_id_2: {}".format(session_id_2))

messages = [{"role": "user", "content": "Hello"}]

Expand All @@ -18,15 +22,17 @@
session_id=session_id_1, # <-- add the agentops session_id to the create function
)

response = openai.chat.completions.create(
model="gpt-3.5-turbo",
messages=messages,
temperature=0.5,
session_id=session_id_2, # <-- add the agentops session_id to the create function
)
agentops.record(ActionEvent(action_type="test event"), session_id=session_id_1)

# response = openai.chat.completions.create(
# model="gpt-3.5-turbo",
# messages=messages,
# temperature=0.5,
# session_id=session_id_2, # <-- add the agentops session_id to the create function
# )

agentops.end_session(end_state="Success", session_id=session_id_1)
agentops.end_session(end_state="Success", session_id=session_id_2)
# agentops.end_session(end_state="Success", session_id=session_id_2)

###
# Used to verify that two sessions are created, each with one LLM event
Expand Down

0 comments on commit 132345e

Please sign in to comment.