diff --git a/agentops/__init__.py b/agentops/__init__.py index 6a38ed0c..cf45c02f 100755 --- a/agentops/__init__.py +++ b/agentops/__init__.py @@ -175,7 +175,6 @@ def record(event: Union[Event, ErrorEvent]): Client().record(event) -@check_init def add_tags(tags: List[str]): """ Append to session tags at runtime. @@ -187,7 +186,6 @@ def add_tags(tags: List[str]): Client().add_tags(tags) -@check_init def set_tags(tags: List[str]): """ Replace session tags at runtime. diff --git a/agentops/agent.py b/agentops/agent.py index 33bb3f9c..06873235 100644 --- a/agentops/agent.py +++ b/agentops/agent.py @@ -19,11 +19,15 @@ def new_init(self, *args, **kwargs): original_init(self, *args, **kwargs) self.agent_ops_agent_id = str(uuid4()) - if kwargs.get("session_id", None): - self.agent_ops_session_id = kwargs.get("session_id") + session_id = None + if kwargs.get("session", None): + session_id = kwargs.get("session").session_id + self.agent_ops_session_id = session_id Client().create_agent( - name=self.agent_ops_agent_name, agent_id=self.agent_ops_agent_id + name=self.agent_ops_agent_name, + agent_id=self.agent_ops_agent_id, + session_id=session_id, ) except AttributeError as e: logger.warning( diff --git a/agentops/client.py b/agentops/client.py index 329c4956..3c1eb32c 100644 --- a/agentops/client.py +++ b/agentops/client.py @@ -25,7 +25,9 @@ get_partner_frameworks, singleton, conditional_singleton, + safe_serialize, ) +from .http_client import HttpClient from .session import Session from .host_env import get_host_env from .log_config import logger @@ -193,10 +195,12 @@ def set_tags(self, tags: List[str]) -> None: Args: tags (List[str]): The list of tags to set. """ - self._tags_for_future_session = tags - session = self._safe_get_session() - session.set_tags(tags=tags) + try: + session = self._safe_get_session() + session.set_tags(tags=tags) + except ValueError: + self._tags_for_future_session = tags def record(self, event: Union[Event, ErrorEvent]) -> None: """ @@ -433,12 +437,34 @@ def end_session( self._sessions.remove(session) - def create_agent(self, name: str, agent_id: Optional[str] = None): + def create_agent( + self, + name: str, + agent_id: Optional[str] = None, + session: Optional[Session] = None, + ): if agent_id is None: agent_id = str(uuid4()) - if self._worker: - self._worker.create_agent(name=name, agent_id=agent_id) - return agent_id + + # if a session is passed in, use multi-session logic + if session: + return session.create_agent(name=name, agent_id=agent_id) + else: + # if no session passed, assume single session + session = self._safe_get_session() + payload = { + "id": agent_id, + "name": name, + } + + serialized_payload = safe_serialize(payload).encode("utf-8") + HttpClient.post( + f"{self.config.endpoint}/v2/create_agent", + serialized_payload, + jwt=session.jwt, + ) + + return agent_id def _handle_unclean_exits(self): def cleanup(end_state: str = "Fail", end_state_reason: Optional[str] = None): diff --git a/agentops/session.py b/agentops/session.py index e878675e..1a43e766 100644 --- a/agentops/session.py +++ b/agentops/session.py @@ -8,7 +8,7 @@ from .config import ClientConfiguration from .helpers import get_ISO_time, filter_unjsonable, safe_serialize from typing import Optional, List, Union -from uuid import UUID +from uuid import UUID, uuid4 from .http_client import HttpClient from .worker import Worker @@ -224,3 +224,19 @@ def run(self) -> None: time.sleep(self.config.max_wait_time / 1000) if self.queue: self.flush_queue() + + def create_agent(self, name, agent_id): + if agent_id is None: + agent_id = str(uuid4()) + + payload = { + "id": agent_id, + "name": name, + } + + serialized_payload = safe_serialize(payload).encode("utf-8") + HttpClient.post( + f"{self.config.endpoint}/v2/create_agent", serialized_payload, jwt=self.jwt + ) + + return agent_id diff --git a/agentops/worker.py b/agentops/worker.py index 1ce4807f..e814a994 100644 --- a/agentops/worker.py +++ b/agentops/worker.py @@ -65,18 +65,6 @@ def flush_queue(self) -> None: logger.debug(serialized_payload) logger.debug("\n") - def create_agent(self, agent_id, name, session_id: str): - payload = { - "id": agent_id, - "name": name, - "session_id": session_id, - } - - serialized_payload = safe_serialize(payload).encode("utf-8") - HttpClient.post( - f"{self.config.endpoint}/v2/create_agent", serialized_payload, jwt=self.jwt - ) - def run(self) -> None: while not self.stop_flag.is_set(): time.sleep(self.config.max_wait_time / 1000) diff --git a/tests/test_session.py b/tests/test_session.py index e1aa6d86..397d79a1 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -127,6 +127,20 @@ def test_inherit_session_id(self, mock_req): agentops.end_all_sessions() + def test_add_tags_before_session(self, mock_req): + agentops.add_tags(["pre-session-tag"]) + agentops.start_session(config=self.config) + + request_json = mock_req.last_request.json() + assert request_json["session"]["tags"] == ["pre-session-tag"] + + def test_set_tags_before_session(self, mock_req): + agentops.set_tags(["pre-session-tag"]) + agentops.start_session(config=self.config) + + request_json = mock_req.last_request.json() + assert request_json["session"]["tags"] == ["pre-session-tag"] + class TestMultiSessions: def setup_method(self):