Skip to content

Commit

Permalink
update agent server management test
Browse files Browse the repository at this point in the history
  • Loading branch information
pan-x-c committed May 21, 2024
1 parent d82c4df commit 485537e
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 4 deletions.
27 changes: 27 additions & 0 deletions src/agentscope/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Union
from typing import Any
from typing import Type
import json
import uuid
from loguru import logger

Expand Down Expand Up @@ -70,6 +71,10 @@ def __call__(cls, *args: tuple, **kwargs: dict) -> Any:
"lazy_launch",
True,
),
upload_source_code=to_dist.pop( # type: ignore[arg-type]
"upload_source_code",
False,
),
agent_id=cls.generate_agent_id(),
connect_existing=False,
agent_class=cls,
Expand Down Expand Up @@ -99,6 +104,7 @@ def __init__(
max_timeout_seconds: int = 1800,
local_mode: bool = True,
lazy_launch: bool = True,
upload_source_code: bool = False,
):
"""Init the distributed configuration.
Expand All @@ -116,13 +122,20 @@ def __init__(
requests.
lazy_launch (`bool`, defaults to `True`):
Only launch the server when the agent is called.
upload_source_code (`bool`, defaults to `False`):
Upload the source code of the agent to the agent server.
Only takes effect when connecting to an existing server.
When you are using an agent that doens't exist on the server
(such as your customized agent that is not officially provided
by AgentScope), please set this value to `True`.
"""
self["host"] = host
self["port"] = port
self["max_pool_size"] = max_pool_size
self["max_timeout_seconds"] = max_timeout_seconds
self["local_mode"] = local_mode
self["lazy_launch"] = lazy_launch
self["upload_source_code"] = upload_source_code


class AgentBase(Operator, metaclass=_AgentMeta):
Expand Down Expand Up @@ -358,6 +371,20 @@ def _broadcast_to_audience(self, x: dict) -> None:
for agent in self._audience:
agent.observe(x)

def __str__(self) -> str:
serialized_fields = {
"name": self.name,
"type": self.__class__.__name__,
"agent_id": self.agent_id,
}
if hasattr(self, "model"):
serialized_fields["model"] = {
"model_type": self.model.model_type,
"config_name": self.model.config_name,
"model_name": self.model.model_name,
}
return json.dumps(serialized_fields, ensure_ascii=False)

@property
def agent_id(self) -> str:
"""The unique id of this agent.
Expand Down
6 changes: 2 additions & 4 deletions src/agentscope/rpc/rpc_agent_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def update_placeholder(self, task_id: int) -> str:
)
return result_msg.value

def get_agent_id_list(self, agent_id: str) -> Sequence[str]:
def get_agent_id_list(self) -> Sequence[str]:
"""
Get id of all agents on the server as a list.
Expand All @@ -235,9 +235,7 @@ def get_agent_id_list(self, agent_id: str) -> Sequence[str]:
"""
with grpc.insecure_channel(f"{self.host}:{self.port}") as channel:
stub = RpcAgentStub(channel)
resp = stub.get_agent_id_list(
agent_pb2.AgentIds(agent_ids=[agent_id]),
)
resp = stub.get_agent_id_list(Empty())
return resp.agent_ids

def get_agent_info(self, agent_id: str = None) -> dict:
Expand Down
48 changes: 48 additions & 0 deletions src/agentscope/server/servicer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# -*- coding: utf-8 -*-
""" Server of distributed agent"""
import os
import threading
import traceback
import json
from concurrent import futures
from loguru import logger
import psutil

try:
import dill
Expand Down Expand Up @@ -71,6 +74,7 @@ def __init__(
self.agent_id_lock = threading.Lock()
self.task_id_counter = 0
self.agent_pool: dict[str, AgentBase] = {}
self.pid = os.getpid()

def get_task_id(self) -> int:
"""Get the auto-increment task id.
Expand Down Expand Up @@ -236,6 +240,50 @@ def update_placeholder(
break
return agent_pb2.RpcMsg(value=result.serialize())

def get_agent_id_list(
self,
request: Empty,
context: ServicerContext,
) -> agent_pb2.AgentIds:
"""Get id of all agents on the server as a list."""
with self.agent_id_lock:
agent_ids = self.agent_pool.keys()
return agent_pb2.AgentIds(agent_ids=agent_ids)

def get_agent_info(
self,
request: agent_pb2.AgentIds,
context: ServicerContext,
) -> agent_pb2.StatusResponse:
"""Get the agent information of the specific agent_id"""
result = {}
with self.agent_id_lock:
for agent_id in request.agent_ids:
if agent_id in self.agent_pool:
result[agent_id] = str(self.agent_pool[agent_id])
else:
logger.warning(
f"Getting info of a non-existent agent [{agent_id}].",
)
return agent_pb2.StatusResponse(
ok=True,
message=json.dumps(result),
)

def get_server_info(
self,
request: Empty,
context: ServicerContext,
) -> agent_pb2.StatusResponse:
"""Get the agent server resource usage information."""
status = {}
status["pid"] = self.pid
process = psutil.Process(self.pid)
status["CPU Times"] = process.cpu_times()
status["CPU Percent"] = process.cpu_percent()
status["Memory Usage"] = process.memory_info().rss
return agent_pb2.StatusResponse(ok=True, message=json.dumps(status))

def _reply(self, request: agent_pb2.RpcMsg) -> agent_pb2.RpcMsg:
"""Call function of RpcAgentService
Expand Down
33 changes: 33 additions & 0 deletions tests/rpc_agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,3 +648,36 @@ def reply(self, x: dict = None) -> dict:
upload_source_code=True,
)
launcher.shutdown()

def test_agent_server_management_funcs(self) -> None:
"""Test agent server management functions"""
launcher = RpcAgentServerLauncher(
host="localhost",
port=12010,
local_mode=False,
)
launcher.launch()
client = RpcAgentClient(host="localhost", port=launcher.port)
agent_ids = client.get_agent_id_list()
self.assertEqual(len(agent_ids), 0)
agent = DemoRpcAgent(
name="demo",
to_dist={
"host": "localhost",
"port": launcher.port,
"upload_source_code": True,
},
)
agent_ids = client.get_agent_id_list()
self.assertEqual(len(agent_ids), 1)
self.assertEqual(agent_ids[0], agent.agent_id)
agent_info = client.get_agent_info(agent_id=agent.agent_id)
logger.info(agent_info)
self.assertTrue(agent.agent_id in agent_info)
server_info = client.get_server_info()
logger.info(server_info)
self.assertTrue("pid" in server_info)
self.assertTrue("CPU Times" in server_info)
self.assertTrue("CPU Percent" in server_info)
self.assertTrue("Memory Usage" in server_info)
launcher.shutdown()

0 comments on commit 485537e

Please sign in to comment.