diff --git a/src/agentscope/agents/agent.py b/src/agentscope/agents/agent.py index 7894345df..10e125528 100644 --- a/src/agentscope/agents/agent.py +++ b/src/agentscope/agents/agent.py @@ -8,6 +8,7 @@ from typing import Union from typing import Any from typing import Type +import json import uuid from loguru import logger @@ -70,6 +71,10 @@ def __call__(cls, *args: tuple, **kwargs: dict) -> Any: "lazy_launch", True, ), + upload_source_code=to_dist.pop( # type: ignore[arg-type] + "upload_source_code", + False, + ), agent_id=cls.generate_agent_id(), connect_existing=False, agent_class=cls, @@ -99,6 +104,7 @@ def __init__( max_timeout_seconds: int = 1800, local_mode: bool = True, lazy_launch: bool = True, + upload_source_code: bool = False, ): """Init the distributed configuration. @@ -116,6 +122,12 @@ def __init__( requests. lazy_launch (`bool`, defaults to `True`): Only launch the server when the agent is called. + upload_source_code (`bool`, defaults to `False`): + Upload the source code of the agent to the agent server. + Only takes effect when connecting to an existing server. + When you are using an agent that doens't exist on the server + (such as your customized agent that is not officially provided + by AgentScope), please set this value to `True`. """ self["host"] = host self["port"] = port @@ -123,6 +135,7 @@ def __init__( self["max_timeout_seconds"] = max_timeout_seconds self["local_mode"] = local_mode self["lazy_launch"] = lazy_launch + self["upload_source_code"] = upload_source_code class AgentBase(Operator, metaclass=_AgentMeta): @@ -358,6 +371,20 @@ def _broadcast_to_audience(self, x: dict) -> None: for agent in self._audience: agent.observe(x) + def __str__(self) -> str: + serialized_fields = { + "name": self.name, + "type": self.__class__.__name__, + "agent_id": self.agent_id, + } + if hasattr(self, "model"): + serialized_fields["model"] = { + "model_type": self.model.model_type, + "config_name": self.model.config_name, + "model_name": self.model.model_name, + } + return json.dumps(serialized_fields, ensure_ascii=False) + @property def agent_id(self) -> str: """The unique id of this agent. diff --git a/src/agentscope/rpc/rpc_agent_client.py b/src/agentscope/rpc/rpc_agent_client.py index 4a7769598..94d66ba87 100644 --- a/src/agentscope/rpc/rpc_agent_client.py +++ b/src/agentscope/rpc/rpc_agent_client.py @@ -226,7 +226,7 @@ def update_placeholder(self, task_id: int) -> str: ) return result_msg.value - def get_agent_id_list(self, agent_id: str) -> Sequence[str]: + def get_agent_id_list(self) -> Sequence[str]: """ Get id of all agents on the server as a list. @@ -235,9 +235,7 @@ def get_agent_id_list(self, agent_id: str) -> Sequence[str]: """ with grpc.insecure_channel(f"{self.host}:{self.port}") as channel: stub = RpcAgentStub(channel) - resp = stub.get_agent_id_list( - agent_pb2.AgentIds(agent_ids=[agent_id]), - ) + resp = stub.get_agent_id_list(Empty()) return resp.agent_ids def get_agent_info(self, agent_id: str = None) -> dict: diff --git a/src/agentscope/server/servicer.py b/src/agentscope/server/servicer.py index 37efd8a1b..42bf3b3f9 100644 --- a/src/agentscope/server/servicer.py +++ b/src/agentscope/server/servicer.py @@ -1,9 +1,12 @@ # -*- coding: utf-8 -*- """ Server of distributed agent""" +import os import threading import traceback +import json from concurrent import futures from loguru import logger +import psutil try: import dill @@ -71,6 +74,7 @@ def __init__( self.agent_id_lock = threading.Lock() self.task_id_counter = 0 self.agent_pool: dict[str, AgentBase] = {} + self.pid = os.getpid() def get_task_id(self) -> int: """Get the auto-increment task id. @@ -236,6 +240,50 @@ def update_placeholder( break return agent_pb2.RpcMsg(value=result.serialize()) + def get_agent_id_list( + self, + request: Empty, + context: ServicerContext, + ) -> agent_pb2.AgentIds: + """Get id of all agents on the server as a list.""" + with self.agent_id_lock: + agent_ids = self.agent_pool.keys() + return agent_pb2.AgentIds(agent_ids=agent_ids) + + def get_agent_info( + self, + request: agent_pb2.AgentIds, + context: ServicerContext, + ) -> agent_pb2.StatusResponse: + """Get the agent information of the specific agent_id""" + result = {} + with self.agent_id_lock: + for agent_id in request.agent_ids: + if agent_id in self.agent_pool: + result[agent_id] = str(self.agent_pool[agent_id]) + else: + logger.warning( + f"Getting info of a non-existent agent [{agent_id}].", + ) + return agent_pb2.StatusResponse( + ok=True, + message=json.dumps(result), + ) + + def get_server_info( + self, + request: Empty, + context: ServicerContext, + ) -> agent_pb2.StatusResponse: + """Get the agent server resource usage information.""" + status = {} + status["pid"] = self.pid + process = psutil.Process(self.pid) + status["CPU Times"] = process.cpu_times() + status["CPU Percent"] = process.cpu_percent() + status["Memory Usage"] = process.memory_info().rss + return agent_pb2.StatusResponse(ok=True, message=json.dumps(status)) + def _reply(self, request: agent_pb2.RpcMsg) -> agent_pb2.RpcMsg: """Call function of RpcAgentService diff --git a/tests/rpc_agent_test.py b/tests/rpc_agent_test.py index 4b3d6f832..d4d5f8d41 100644 --- a/tests/rpc_agent_test.py +++ b/tests/rpc_agent_test.py @@ -648,3 +648,36 @@ def reply(self, x: dict = None) -> dict: upload_source_code=True, ) launcher.shutdown() + + def test_agent_server_management_funcs(self) -> None: + """Test agent server management functions""" + launcher = RpcAgentServerLauncher( + host="localhost", + port=12010, + local_mode=False, + ) + launcher.launch() + client = RpcAgentClient(host="localhost", port=launcher.port) + agent_ids = client.get_agent_id_list() + self.assertEqual(len(agent_ids), 0) + agent = DemoRpcAgent( + name="demo", + to_dist={ + "host": "localhost", + "port": launcher.port, + "upload_source_code": True, + }, + ) + agent_ids = client.get_agent_id_list() + self.assertEqual(len(agent_ids), 1) + self.assertEqual(agent_ids[0], agent.agent_id) + agent_info = client.get_agent_info(agent_id=agent.agent_id) + logger.info(agent_info) + self.assertTrue(agent.agent_id in agent_info) + server_info = client.get_server_info() + logger.info(server_info) + self.assertTrue("pid" in server_info) + self.assertTrue("CPU Times" in server_info) + self.assertTrue("CPU Percent" in server_info) + self.assertTrue("Memory Usage" in server_info) + launcher.shutdown()