From bc2e24ba6bab26108fd350fbbc1ad51472258ea5 Mon Sep 17 00:00:00 2001 From: DavdGao Date: Tue, 27 Aug 2024 10:40:48 +0800 Subject: [PATCH] Support to serialize message objects in AgentScope and remove unused arguments. (#388) --------- Co-authored-by: qbc --- .../en/source/tutorial/201-agent.md | 1 - .../zh_CN/source/tutorial/201-agent.md | 1 - .../conversation_moa.py | 4 - examples/conversation_nl2sql/sql_utils.py | 6 +- .../auto-discussion.py | 2 +- .../rag_example.py | 6 +- examples/conversation_with_mentions/main.py | 18 +- .../conversation_with_swe-agent/swe_agent.py | 4 +- .../answerer_agent.py | 5 +- .../searcher_agent.py | 4 +- examples/distributed_simulation/main.py | 6 +- .../distributed_simulation/participant.py | 4 +- .../src/alg_agents.py | 8 +- .../paper_llm_based_algorithm/src/counting.py | 2 +- examples/paper_llm_based_algorithm/src/rag.py | 6 +- .../src/retrieval.py | 2 +- .../paper_llm_based_algorithm/src/sorting.py | 4 +- src/agentscope/agents/agent.py | 26 +- src/agentscope/agents/dialog_agent.py | 14 +- src/agentscope/agents/dict_dialog_agent.py | 4 - src/agentscope/agents/rag_agent.py | 4 +- src/agentscope/agents/rpc_agent.py | 12 +- src/agentscope/agents/text_to_image_agent.py | 4 - src/agentscope/agents/user_agent.py | 1 - src/agentscope/logging.py | 4 +- src/agentscope/memory/memory.py | 20 - src/agentscope/memory/temporary_memory.py | 60 +-- src/agentscope/message/__init__.py | 8 +- src/agentscope/message/msg.py | 356 +++++++++++------- src/agentscope/message/placeholder.py | 340 ++++++++++------- src/agentscope/rpc/rpc_agent_client.py | 7 +- src/agentscope/serialize.py | 65 ++++ src/agentscope/server/servicer.py | 38 +- .../service/multi_modality/openai_services.py | 4 +- .../html-drag-components/message-msg.html | 5 + .../studio/static/js/workstation.js | 16 + .../static/workstation_templates/en4.json | 1 + src/agentscope/test.py | 0 src/agentscope/utils/tools.py | 4 +- tests/agent_test.py | 3 - tests/logger_test.py | 43 +-- tests/memory_test.py | 11 +- tests/message_test.py | 44 +++ tests/msghub_test.py | 10 +- tests/retrieval_from_list_test.py | 6 +- tests/rpc_agent_test.py | 58 +-- tests/serialize_test.py | 100 +++++ 47 files changed, 849 insertions(+), 502 deletions(-) create mode 100644 src/agentscope/serialize.py create mode 100644 src/agentscope/test.py create mode 100644 tests/message_test.py create mode 100644 tests/serialize_test.py diff --git a/docs/sphinx_doc/en/source/tutorial/201-agent.md b/docs/sphinx_doc/en/source/tutorial/201-agent.md index 3fa916a88..d28838497 100644 --- a/docs/sphinx_doc/en/source/tutorial/201-agent.md +++ b/docs/sphinx_doc/en/source/tutorial/201-agent.md @@ -35,7 +35,6 @@ class AgentBase(Operator): sys_prompt: Optional[str] = None, model_config_name: str = None, use_memory: bool = True, - memory_config: Optional[dict] = None, ) -> None: # ... [code omitted for brevity] diff --git a/docs/sphinx_doc/zh_CN/source/tutorial/201-agent.md b/docs/sphinx_doc/zh_CN/source/tutorial/201-agent.md index 10b29aeba..2e15490ad 100644 --- a/docs/sphinx_doc/zh_CN/source/tutorial/201-agent.md +++ b/docs/sphinx_doc/zh_CN/source/tutorial/201-agent.md @@ -36,7 +36,6 @@ class AgentBase(Operator): sys_prompt: Optional[str] = None, model_config_name: str = None, use_memory: bool = True, - memory_config: Optional[dict] = None, ) -> None: # ... [code omitted for brevity] diff --git a/examples/conversation_mixture_of_agents/conversation_moa.py b/examples/conversation_mixture_of_agents/conversation_moa.py index e1cc4260d..0dd7a613d 100644 --- a/examples/conversation_mixture_of_agents/conversation_moa.py +++ b/examples/conversation_mixture_of_agents/conversation_moa.py @@ -21,7 +21,6 @@ def __init__( name: str, moa_module: MixtureOfAgents, # changed to passing moa_module here use_memory: bool = True, - memory_config: Optional[dict] = None, ) -> None: """Initialize the dialog agent. @@ -35,14 +34,11 @@ def __init__( The inited MoA module you want to use as the main module. use_memory (`bool`, defaults to `True`): Whether the agent has memory. - memory_config (`Optional[dict]`): - The config of memory. """ super().__init__( name=name, sys_prompt="", use_memory=use_memory, - memory_config=memory_config, ) self.moa_module = moa_module # change model init to moa_module diff --git a/examples/conversation_nl2sql/sql_utils.py b/examples/conversation_nl2sql/sql_utils.py index 98f70ec36..5960b88f1 100644 --- a/examples/conversation_nl2sql/sql_utils.py +++ b/examples/conversation_nl2sql/sql_utils.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Utils and helpers for performing sql querys. +Utils and helpers for performing sql queries. Referenced from https://github.com/BeachWang/DAIL-SQL. """ import sqlite3 @@ -261,11 +261,10 @@ def is_sql_question_prompt(self, question: str) -> str: } return self.sql_prompt.is_sql_question(target) - def generate_prompt(self, x: dict = None) -> dict: + def generate_prompt(self, question: str) -> dict: """ Generate prompt given input question """ - question = x["content"] target = { "path_db": self.db_path, "question": question, @@ -277,7 +276,6 @@ def generate_prompt(self, x: dict = None) -> dict: self.NUM_EXAMPLE * self.scope_factor, ) prompt_example = [] - question = target["question"] example_prefix = self.question_style.get_example_prefix() for example in examples: example_format = self.question_style.format_example(example) diff --git a/examples/conversation_self_organizing/auto-discussion.py b/examples/conversation_self_organizing/auto-discussion.py index 6470884be..8b44bc4df 100644 --- a/examples/conversation_self_organizing/auto-discussion.py +++ b/examples/conversation_self_organizing/auto-discussion.py @@ -55,7 +55,7 @@ x = Msg("user", x, role="user") settings = agent_builder(x) -scenario_participants = extract_scenario_and_participants(settings["content"]) +scenario_participants = extract_scenario_and_participants(settings.content) # set the agents that participant the discussion agents = [ diff --git a/examples/conversation_with_RAG_agents/rag_example.py b/examples/conversation_with_RAG_agents/rag_example.py index 283c014b2..9946cd888 100644 --- a/examples/conversation_with_RAG_agents/rag_example.py +++ b/examples/conversation_with_RAG_agents/rag_example.py @@ -127,15 +127,15 @@ def main() -> None: # 5. repeat x = user_agent() x.role = "user" # to enforce dashscope requirement on roles - if len(x["content"]) == 0 or str(x["content"]).startswith("exit"): + if len(x.content) == 0 or str(x.content).startswith("exit"): break - speak_list = filter_agents(x.get("content", ""), rag_agent_list) + speak_list = filter_agents(x.content, rag_agent_list) if len(speak_list) == 0: guide_response = guide_agent(x) # Only one agent can be called in the current version, # we may support multi-agent conversation later speak_list = filter_agents( - guide_response.get("content", ""), + guide_response.content, rag_agent_list, ) agent_name_list = [agent.name for agent in speak_list] diff --git a/examples/conversation_with_mentions/main.py b/examples/conversation_with_mentions/main.py index 94352adc9..d51616150 100644 --- a/examples/conversation_with_mentions/main.py +++ b/examples/conversation_with_mentions/main.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- """ A group chat where user can talk any time implemented by agentscope. """ -from loguru import logger from groupchat_utils import ( select_next_one, filter_agents, @@ -50,18 +49,11 @@ def main() -> None: speak_list = [] with msghub(agents, announcement=hint): while True: - try: - x = user(timeout=USER_TIME_TO_SPEAK) - if x.content == "exit": - break - except TimeoutError: - x = {"content": ""} - logger.info( - f"User has not typed text for " - f"{USER_TIME_TO_SPEAK} seconds, skip.", - ) - - speak_list += filter_agents(x.get("content", ""), npc_agents) + x = user(timeout=USER_TIME_TO_SPEAK) + if x.content == "exit": + break + + speak_list += filter_agents(x.content, npc_agents) if len(speak_list) > 0: next_agent = speak_list.pop(0) diff --git a/examples/conversation_with_swe-agent/swe_agent.py b/examples/conversation_with_swe-agent/swe_agent.py index 6d2c49424..f154c4865 100644 --- a/examples/conversation_with_swe-agent/swe_agent.py +++ b/examples/conversation_with_swe-agent/swe_agent.py @@ -197,7 +197,7 @@ def step(self) -> Msg: # parse and execute action action = res.parsed.get("action") - obs = self.prase_command(res.parsed["action"]) + obs = self.parse_command(res.parsed["action"]) self.speak( Msg(self.name, "\n====Observation====\n" + obs, role="assistant"), ) @@ -214,7 +214,7 @@ def reply(self, x: Optional[Union[Msg, Sequence[Msg]]] = None) -> Msg: action_name = msg.content["action"]["name"] return msg - def prase_command(self, command_call: dict) -> str: + def parse_command(self, command_call: dict) -> str: command_name = command_call["name"] command_args = command_call["arguments"] if command_name == "exit": diff --git a/examples/distributed_parallel_optimization/answerer_agent.py b/examples/distributed_parallel_optimization/answerer_agent.py index e44551d01..a5c87f9f3 100644 --- a/examples/distributed_parallel_optimization/answerer_agent.py +++ b/examples/distributed_parallel_optimization/answerer_agent.py @@ -37,6 +37,7 @@ def reply(self, x: Optional[Union[Msg, Sequence[Msg]]] = None) -> Msg: return Msg( self.name, content=f"Unable to load web page [{x.url}].", + role="assistant", url=x.url, ) # prepare prompt @@ -49,12 +50,12 @@ def reply(self, x: Optional[Union[Msg, Sequence[Msg]]] = None) -> Msg: " the following web page:\n\n" f"{response['html_to_text']}" f"\n\nBased on the above web page," - f" please answer my question\n{x.query}", + f" please answer my question\n{x.metadata}", ), ) # call llm and generate response response = self.model(prompt).text - msg = Msg(self.name, content=response, url=x.url) + msg = Msg(self.name, content=response, role="assistant", url=x.url) self.speak(msg) diff --git a/examples/distributed_parallel_optimization/searcher_agent.py b/examples/distributed_parallel_optimization/searcher_agent.py index eb1ad2f23..8e3f46a68 100644 --- a/examples/distributed_parallel_optimization/searcher_agent.py +++ b/examples/distributed_parallel_optimization/searcher_agent.py @@ -80,11 +80,13 @@ def reply(self, x: Optional[Union[Msg, Sequence[Msg]]] = None) -> Msg: Msg( name=self.name, content=result, + role="assistant", url=result["link"], - query=x.content, + metadata=x.content, ) for result in results ], + role="assistant", ) self.speak( Msg( diff --git a/examples/distributed_simulation/main.py b/examples/distributed_simulation/main.py index 93f5141a9..1ccad506e 100644 --- a/examples/distributed_simulation/main.py +++ b/examples/distributed_simulation/main.py @@ -187,10 +187,10 @@ def run_main_process( cnt = 0 for r in results: try: - summ += int(r["content"]["sum"]) - cnt += int(r["content"]["cnt"]) + summ += int(r.content["sum"]) + cnt += int(r.content["cnt"]) except Exception: - logger.error(r["content"]) + logger.error(r.content) et = time.time() logger.chat( Msg( diff --git a/examples/distributed_simulation/participant.py b/examples/distributed_simulation/participant.py index f017d4de3..3dc147a3f 100644 --- a/examples/distributed_simulation/participant.py +++ b/examples/distributed_simulation/participant.py @@ -36,7 +36,7 @@ def reply(self, x: Optional[Union[Msg, Sequence[Msg]]] = None) -> Msg: """Generate a random value""" # generate a response in content response = self.generate_random_response() - msg = Msg(self.name, content=response) + msg = Msg(self.name, content=response, role="assistant") return msg @@ -148,7 +148,7 @@ def reply(self, x: Optional[Union[Msg, Sequence[Msg]]] = None) -> Msg: summ = 0 for r in results: try: - summ += int(r["content"]) + summ += int(r.content) except Exception as e: print(e) return Msg( diff --git a/examples/paper_llm_based_algorithm/src/alg_agents.py b/examples/paper_llm_based_algorithm/src/alg_agents.py index 5004a9e49..8eb0cbb24 100644 --- a/examples/paper_llm_based_algorithm/src/alg_agents.py +++ b/examples/paper_llm_based_algorithm/src/alg_agents.py @@ -90,14 +90,14 @@ def invoke_llm_call( # Update relevant self.cost_metrics self.cost_metrics["llm_calls"] += 1 self.cost_metrics["prefilling_length_total"] += len( - x_request["content"], + x_request.content, ) + len(dialog_agent.sys_prompt) - self.cost_metrics["decoding_length_total"] += len(x["content"]) + self.cost_metrics["decoding_length_total"] += len(x.content) self.cost_metrics["prefilling_tokens_total"] += num_tokens_from_string( - x_request["content"], + x_request.content, ) + num_tokens_from_string(dialog_agent.sys_prompt) self.cost_metrics["decoding_tokens_total"] += num_tokens_from_string( - x["content"], + x.content, ) return x diff --git a/examples/paper_llm_based_algorithm/src/counting.py b/examples/paper_llm_based_algorithm/src/counting.py index 5df8c5538..2ff9fff6e 100644 --- a/examples/paper_llm_based_algorithm/src/counting.py +++ b/examples/paper_llm_based_algorithm/src/counting.py @@ -58,7 +58,7 @@ def solve_directly( for i in range(nsamples): x = self.invoke_llm_call(x_request, dialog_agent) candidate_solutions[i] = self.parse_llm_response_counting( - x["content"], + x.content, ) # int solution = max( diff --git a/examples/paper_llm_based_algorithm/src/rag.py b/examples/paper_llm_based_algorithm/src/rag.py index c37508ab4..173801402 100644 --- a/examples/paper_llm_based_algorithm/src/rag.py +++ b/examples/paper_llm_based_algorithm/src/rag.py @@ -134,7 +134,7 @@ def solve(self, request_string: str, question: str) -> dict: # ) # x_request = request_agent(x=None, content=content) # lst_x[i] = self.invoke_llm_call(x_request, dialog_agents[i]) - # sub_contents = [x["content"] for x in lst_x] + # sub_contents = [x.content for x in lst_x] # sub_solutions = ["" for _ in range(len(sub_requests))] # for i in range(len(sub_solutions)): # ss = self.parse_llm_response_retrieve_relevant_sentences( @@ -158,7 +158,7 @@ def solve(self, request_string: str, question: str) -> dict: x_request = request_agent(x=None, content=content) x = self.invoke_llm_call(x_request, dialog_agent) ss = self.parse_llm_response_retrieve_relevant_sentences( - x["content"], + x.content, ) sub_solutions[i] = ss sub_latencies[i] = time.time() - time_start @@ -183,7 +183,7 @@ def solve(self, request_string: str, question: str) -> dict: content = self.prompt_generate_final_answer(context, question) x_request = request_agent(x=None, content=content) x = self.invoke_llm_call(x_request, dialog_agent) - solution = self.parse_llm_response_generate_final_answer(x["content"]) + solution = self.parse_llm_response_generate_final_answer(x.content) final_step_latency = time.time() - time_start result = { diff --git a/examples/paper_llm_based_algorithm/src/retrieval.py b/examples/paper_llm_based_algorithm/src/retrieval.py index 1e857e20f..da0f77c43 100644 --- a/examples/paper_llm_based_algorithm/src/retrieval.py +++ b/examples/paper_llm_based_algorithm/src/retrieval.py @@ -84,7 +84,7 @@ def solve_directly( content = self.prompt_retrieval(request_string, question) x_request = request_agent(x=None, content=content) x = self.invoke_llm_call(x_request, dialog_agent) - solution = self.parse_llm_response_retrieval(x["content"]) + solution = self.parse_llm_response_retrieval(x.content) return solution def solve_decomposition(self, request_string: str, question: str) -> dict: diff --git a/examples/paper_llm_based_algorithm/src/sorting.py b/examples/paper_llm_based_algorithm/src/sorting.py index 18a42bca3..849f3f336 100644 --- a/examples/paper_llm_based_algorithm/src/sorting.py +++ b/examples/paper_llm_based_algorithm/src/sorting.py @@ -49,7 +49,7 @@ def solve_directly( content = self.prompt_sorting(request_string) x_request = request_agent(x=None, content=content) x = self.invoke_llm_call(x_request, dialog_agent) - solution = self.parse_llm_response_sorting(x["content"]) + solution = self.parse_llm_response_sorting(x.content) return solution def merge_two_sorted_lists( @@ -90,7 +90,7 @@ def merge_two_sorted_lists( content = self.prompt_merging(request_string) x_request = request_agent(x=None, content=content) x = self.invoke_llm_call(x_request, dialog_agent) - solution = self.parse_llm_response_sorting(x["content"]) + solution = self.parse_llm_response_sorting(x.content) return solution diff --git a/src/agentscope/agents/agent.py b/src/agentscope/agents/agent.py index 7ba872274..e176d6560 100644 --- a/src/agentscope/agents/agent.py +++ b/src/agentscope/agents/agent.py @@ -144,7 +144,6 @@ def __init__( sys_prompt: Optional[str] = None, model_config_name: str = None, use_memory: bool = True, - memory_config: Optional[dict] = None, to_dist: Optional[Union[DistConf, bool]] = False, ) -> None: r"""Initialize an agent from the given arguments. @@ -160,8 +159,6 @@ def __init__( configuration. use_memory (`bool`, defaults to `True`): Whether the agent has memory. - memory_config (`Optional[dict]`): - The config of memory. to_dist (`Optional[Union[DistConf, bool]]`, default to `False`): The configurations passed to :py:meth:`to_dist` method. Used in :py:class:`_AgentMeta`, when this parameter is provided, @@ -189,7 +186,6 @@ def __init__( See :doc:`Tutorial` for detail. """ self.name = name - self.memory_config = memory_config self.sys_prompt = sys_prompt # TODO: support to receive a ModelWrapper instance @@ -200,7 +196,7 @@ def __init__( ) if use_memory: - self.memory = TemporaryMemory(memory_config) + self.memory = TemporaryMemory() else: self.memory = None @@ -276,25 +272,7 @@ def reply(self, x: Optional[Union[Msg, Sequence[Msg]]] = None) -> Msg: f'"reply" function.', ) - def load_from_config(self, config: dict) -> None: - """Load configuration for this agent. - - Args: - config (`dict`): model configuration - """ - - def export_config(self) -> dict: - """Return configuration of this agent. - - Returns: - The configuration of current agent. - """ - return {} - - def load_memory(self, memory: Sequence[dict]) -> None: - r"""Load input memory.""" - - def __call__(self, *args: Any, **kwargs: Any) -> dict: + def __call__(self, *args: Any, **kwargs: Any) -> Msg: """Calling the reply function, and broadcast the generated response to all audiences if needed.""" res = self.reply(*args, **kwargs) diff --git a/src/agentscope/agents/dialog_agent.py b/src/agentscope/agents/dialog_agent.py index cb76f1354..031f0d2cc 100644 --- a/src/agentscope/agents/dialog_agent.py +++ b/src/agentscope/agents/dialog_agent.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- """A general dialog agent.""" -from typing import Optional, Union, Sequence +from typing import Optional, Union, Sequence, Any + +from loguru import logger from ..message import Msg from .agent import AgentBase @@ -16,7 +18,7 @@ def __init__( sys_prompt: str, model_config_name: str, use_memory: bool = True, - memory_config: Optional[dict] = None, + **kwargs: Any, ) -> None: """Initialize the dialog agent. @@ -31,17 +33,19 @@ def __init__( configuration. use_memory (`bool`, defaults to `True`): Whether the agent has memory. - memory_config (`Optional[dict]`): - The config of memory. """ super().__init__( name=name, sys_prompt=sys_prompt, model_config_name=model_config_name, use_memory=use_memory, - memory_config=memory_config, ) + if kwargs: + logger.warning( + f"Unused keyword arguments are provided: {kwargs}", + ) + def reply(self, x: Optional[Union[Msg, Sequence[Msg]]] = None) -> Msg: """Reply function of the agent. Processes the input data, generates a prompt using the current dialogue memory and system diff --git a/src/agentscope/agents/dict_dialog_agent.py b/src/agentscope/agents/dict_dialog_agent.py index 970a7a610..60fcc9e36 100644 --- a/src/agentscope/agents/dict_dialog_agent.py +++ b/src/agentscope/agents/dict_dialog_agent.py @@ -23,7 +23,6 @@ def __init__( sys_prompt: str, model_config_name: str, use_memory: bool = True, - memory_config: Optional[dict] = None, max_retries: Optional[int] = 3, ) -> None: """Initialize the dict dialog agent. @@ -39,8 +38,6 @@ def __init__( configuration. use_memory (`bool`, defaults to `True`): Whether the agent has memory. - memory_config (`Optional[dict]`, defaults to `None`): - The config of memory. max_retries (`Optional[int]`, defaults to `None`): The maximum number of retries when failed to parse the model output. @@ -50,7 +47,6 @@ def __init__( sys_prompt=sys_prompt, model_config_name=model_config_name, use_memory=use_memory, - memory_config=memory_config, ) self.parser = None diff --git a/src/agentscope/agents/rag_agent.py b/src/agentscope/agents/rag_agent.py index 63a23fdcd..ec5a8dc94 100644 --- a/src/agentscope/agents/rag_agent.py +++ b/src/agentscope/agents/rag_agent.py @@ -111,7 +111,7 @@ def reply(self, x: Optional[Union[Msg, Sequence[Msg]]] = None) -> Msg: ) query = ( "/n".join( - [msg["content"] for msg in history], + [msg.content for msg in history], ) if isinstance(history, list) else str(history) @@ -182,7 +182,7 @@ def reply(self, x: Optional[Union[Msg, Sequence[Msg]]] = None) -> Msg: # call llm and generate response response = self.model(prompt).text - msg = Msg(self.name, response) + msg = Msg(self.name, response, "assistant") # Print/speak the message in this agent's voice self.speak(msg) diff --git a/src/agentscope/agents/rpc_agent.py b/src/agentscope/agents/rpc_agent.py index 4a43b5f07..619898a91 100644 --- a/src/agentscope/agents/rpc_agent.py +++ b/src/agentscope/agents/rpc_agent.py @@ -3,12 +3,10 @@ from typing import Type, Optional, Union, Sequence from agentscope.agents.agent import AgentBase -from agentscope.message import ( - PlaceholderMessage, - serialize, - Msg, -) +from agentscope.message import Msg +from agentscope.message import PlaceholderMessage from agentscope.rpc import RpcAgentClient +from agentscope.serialize import serialize from agentscope.server.launcher import RpcAgentServerLauncher from agentscope.studio._client import _studio_client @@ -122,8 +120,6 @@ def reply(self, x: Optional[Union[Msg, Sequence[Msg]]] = None) -> Msg: if self.client is None: self._launch_server() return PlaceholderMessage( - name=self.name, - content=None, client=self.client, x=x, ) @@ -133,7 +129,7 @@ def observe(self, x: Union[Msg, Sequence[Msg]]) -> None: self._launch_server() self.client.call_agent_func( func_name="_observe", - value=serialize(x), # type: ignore[arg-type] + value=serialize(x), ) def clone_instances( diff --git a/src/agentscope/agents/text_to_image_agent.py b/src/agentscope/agents/text_to_image_agent.py index 00519a404..f66d75b32 100644 --- a/src/agentscope/agents/text_to_image_agent.py +++ b/src/agentscope/agents/text_to_image_agent.py @@ -21,7 +21,6 @@ def __init__( name: str, model_config_name: str, use_memory: bool = True, - memory_config: Optional[dict] = None, ) -> None: """Initialize the text to image agent. @@ -33,15 +32,12 @@ def __init__( configuration. use_memory (`bool`, defaults to `True`): Whether the agent has memory. - memory_config (`Optional[dict]`): - The config of memory. """ super().__init__( name=name, sys_prompt="", model_config_name=model_config_name, use_memory=use_memory, - memory_config=memory_config, ) logger.warning( diff --git a/src/agentscope/agents/user_agent.py b/src/agentscope/agents/user_agent.py index b76cf28d5..12b6a26b4 100644 --- a/src/agentscope/agents/user_agent.py +++ b/src/agentscope/agents/user_agent.py @@ -76,7 +76,6 @@ def reply( required_keys=required_keys, ) - print("Python: receive ", raw_input) content = raw_input["content"] url = raw_input["url"] kwargs = {} diff --git a/src/agentscope/logging.py b/src/agentscope/logging.py index 47d7709b7..163ba0577 100644 --- a/src/agentscope/logging.py +++ b/src/agentscope/logging.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- """Logging utilities.""" -import json import os import sys from typing import Optional, Literal, Any @@ -9,6 +8,7 @@ from .utils.tools import _guess_type_by_extension from .message import Msg +from .serialize import serialize from .studio._client import _studio_client from .web.gradio.utils import ( generate_image_from_name, @@ -99,7 +99,7 @@ def _save_msg(msg: Msg) -> None: logger.log( LEVEL_SAVE_MSG, - json.dumps(msg, ensure_ascii=False, default=lambda _: None), + serialize(msg), ) diff --git a/src/agentscope/memory/memory.py b/src/agentscope/memory/memory.py index bf457a3e5..de5430a2a 100644 --- a/src/agentscope/memory/memory.py +++ b/src/agentscope/memory/memory.py @@ -20,26 +20,6 @@ class MemoryBase(ABC): _version: int = 1 - def __init__( - self, - config: Optional[dict] = None, - ) -> None: - """MemoryBase is a base class for memory of agents. - - Args: - config (`Optional[dict]`, defaults to `None`): - Configuration of this memory. - """ - self.config = {} if config is None else config - - def update_config(self, config: dict) -> None: - """ - Configure memory as specified in config - Args: - config (`dict`): Configuration of resetting this memory - """ - self.config = config - @abstractmethod def get_memory( self, diff --git a/src/agentscope/memory/temporary_memory.py b/src/agentscope/memory/temporary_memory.py index 9e7b4aeba..d845a5523 100644 --- a/src/agentscope/memory/temporary_memory.py +++ b/src/agentscope/memory/temporary_memory.py @@ -14,15 +14,11 @@ from .memory import MemoryBase from ..manager import ModelManager +from ..serialize import serialize, deserialize from ..service.retrieval.retrieval_from_list import retrieve_from_list from ..service.retrieval.similarity import Embedding -from ..message import ( - deserialize, - serialize, - MessageBase, - Msg, - PlaceholderMessage, -) +from ..message import Msg +from ..message import PlaceholderMessage class TemporaryMemory(MemoryBase): @@ -32,20 +28,18 @@ class TemporaryMemory(MemoryBase): def __init__( self, - config: Optional[dict] = None, embedding_model: Union[str, Callable] = None, ) -> None: """ Temporary memory module for conversation. + Args: - config (dict): - configuration of the memory embedding_model (Union[str, Callable]) if the temporary memory needs to be embedded, then either pass the name of embedding model or the embedding model itself. """ - super().__init__(config) + super().__init__() self._content = [] @@ -63,7 +57,6 @@ def add( memories: Union[Sequence[Msg], Msg, None], embed: bool = False, ) -> None: - # pylint: disable=too-many-branches """ Adding new memory fragment, depending on how the memory are stored Args: @@ -80,29 +73,25 @@ def add( else: record_memories = memories - # if memory doesn't have id attribute, we skip the checking + # Assert the message types memories_idx = set(_.id for _ in self._content if hasattr(_, "id")) for memory_unit in record_memories: - if not issubclass(type(memory_unit), MessageBase): - try: - memory_unit = Msg(**memory_unit) - except Exception as exc: - raise ValueError( - f"Cannot add {memory_unit} to memory, " - f"must be with subclass of MessageBase", - ) from exc - # in case this is a PlaceholderMessage, try to update # the values first + # TODO: Unify PlaceholderMessage and Msg into one class to avoid + # type error if isinstance(memory_unit, PlaceholderMessage): memory_unit.update_value() - memory_unit = Msg(**memory_unit) + memory_unit = Msg.from_dict(memory_unit.to_dict()) + + if not isinstance(memory_unit, Msg): + raise ValueError( + f"Cannot add {type(memory_unit)} to memory, " + f"must be a Msg object.", + ) - # add to memory if it's new - if ( - not hasattr(memory_unit, "id") - or memory_unit.id not in memories_idx - ): + # Add to memory if it's new + if memory_unit.id not in memories_idx: if embed: if self.embedding_model: # TODO: embed only content or its string representation @@ -220,8 +209,21 @@ def load( e.doc, e.pos, ) - else: + elif isinstance(memories, list): + for unit in memories: + if not isinstance(unit, Msg): + raise TypeError( + f"Expect a list of Msg objects, but get {type(unit)} " + f"instead.", + ) load_memories = memories + elif isinstance(memories, Msg): + load_memories = [memories] + else: + raise TypeError( + f"The type of memories to be loaded is not supported. " + f"Expect str, list[Msg], or Msg, but get {type(memories)}.", + ) # overwrite the original memories after loading the new ones if overwrite: diff --git a/src/agentscope/message/__init__.py b/src/agentscope/message/__init__.py index f26315f3b..419526f87 100644 --- a/src/agentscope/message/__init__.py +++ b/src/agentscope/message/__init__.py @@ -1,12 +1,10 @@ # -*- coding: utf-8 -*- """The message module of AgentScope.""" -from .msg import Msg, MessageBase -from .placeholder import PlaceholderMessage, deserialize, serialize +from .msg import Msg +from .placeholder import PlaceholderMessage __all__ = [ "Msg", - "MessageBase", - "deserialize", - "serialize", + "PlaceholderMessage", ] diff --git a/src/agentscope/message/msg.py b/src/agentscope/message/msg.py index 7a62757c6..342e86dda 100644 --- a/src/agentscope/message/msg.py +++ b/src/agentscope/message/msg.py @@ -1,168 +1,207 @@ # -*- coding: utf-8 -*- +# mypy: disable-error-code="misc" """The base class for message unit""" - -from typing import Any, Optional, Union, Literal, List +from typing import ( + Any, + Literal, + Union, + List, + Optional, +) from uuid import uuid4 -import json from loguru import logger -from ..utils.tools import _get_timestamp, _map_string_to_color_mark +from ..serialize import is_serializable +from ..utils.tools import ( + _map_string_to_color_mark, + _get_timestamp, +) + +class Msg: + """The message class for AgentScope, which is responsible for storing + the information of a message, including -class MessageBase(dict): - """Base Message class, which is used to maintain information for dialog, - memory and used to construct prompt. + - id: the identity of the message + - name: who sends the message + - content: the message content + - role: the sender role chosen from 'system', 'user', 'assistant' + - url: the url(s) refers to multimodal content + - metadata: some additional information + - timestamp: when the message is created """ + __serialized_attrs: set = { + "id", + "name", + "content", + "role", + "url", + "metadata", + "timestamp", + } + """The attributes that need to be serialized and deserialized.""" + def __init__( self, name: str, content: Any, - role: Literal["user", "system", "assistant"] = "assistant", - url: Optional[Union[List[str], str]] = None, - timestamp: Optional[str] = None, + role: Union[str, Literal["system", "user", "assistant"]], + url: Optional[Union[str, List[str]]] = None, + metadata: Optional[Union[dict, str]] = None, + echo: bool = False, **kwargs: Any, ) -> None: - """Initialize the message object + """Initialize the message object. + + There are two ways to initialize a message object: + - Providing `name`, `content`, `role`, `url`(Optional), + `metadata`(Optional) to initialize a normal message object. + - Providing `host`, `port`, `task_id` to initialize a placeholder. + + Normally, users only need to create a normal message object by + providing `name`, `content`, `role`, `url`(Optional) and `metadata` + (Optional). + + The initialization of message has a high priority, which means that + when `name`, `content`, `role`, `host`, `port`, `task_id` are all + provided, the message will be initialized as a normal message object + rather than a placeholder. Args: name (`str`): - The name of who send the message. It's often used in - role-playing scenario to tell the name of the sender. + The name of who generates the message. content (`Any`): The content of the message. - role (`Literal["system", "user", "assistant"]`, defaults to "assistant"): - The role of who send the message. It can be one of the - `"system"`, `"user"`, or `"assistant"`. Default to - `"assistant"`. - url (`Optional[Union[List[str], str]]`, defaults to None): - A url to file, image, video, audio or website. - timestamp (`Optional[str]`, defaults to None): - The timestamp of the message, if None, it will be set to - current time. - **kwargs (`Any`): - Other attributes of the message. - """ # noqa - # id and timestamp will be added to the object as its attributes - # rather than items in dict - self.id = uuid4().hex - if timestamp is None: - self.timestamp = _get_timestamp() - else: - self.timestamp = timestamp + role (`Union[str, Literal["system", "user", "assistant"]]`): + The role of the message sender. + url (`Optional[Union[str, List[str]]`, defaults to `None`): + The url of the message. + metadata (`Optional[Union[dict, str]]`, defaults to `None`): + The additional information stored in the message. + echo (`bool`, defaults to `False`): + Whether to print the message when initializing the message obj. + """ + self.id = uuid4().hex self.name = name self.content = content self.role = role - self.url = url + self.metadata = metadata + self.timestamp = _get_timestamp() - self.update(kwargs) - - def __getattr__(self, key: Any) -> Any: - try: - return self[key] - except KeyError as e: - raise AttributeError(f"no attribute '{key}'") from e - - def __setattr__(self, key: Any, value: Any) -> None: - self[key] = value - - def __delattr__(self, key: Any) -> None: - try: - del self[key] - except KeyError as e: - raise AttributeError(f"no attribute '{key}'") from e - - def serialize(self) -> str: - """Return the serialized message.""" - raise NotImplementedError - - -class Msg(MessageBase): - """The Message class.""" - - id: str - """The id of the message.""" - - name: str - """The name of who send the message.""" - - content: Any - """The content of the message.""" - - role: Literal["system", "user", "assistant"] - """The role of the message sender.""" - - metadata: Optional[dict] - """Save the information for application's control flow, or other - purposes.""" + if kwargs: + logger.warning( + f"In current version, the message class in AgentScope does not" + f" inherit the dict class. " + f"The input arguments {kwargs} are not used.", + ) - url: Optional[Union[List[str], str]] - """A url to file, image, video, audio or website.""" + if echo: + logger.chat(self) - timestamp: str - """The timestamp of the message.""" + def __getitem__(self, item: str) -> Any: + """The getitem function, which will be deprecated in the new version""" + logger.warning( + f"The Msg class doesn't inherit dict any more. Please refer to " + f"its attribute by `msg.{item}` directly." + f"The support of __getitem__ will also be deprecated in the " + f"future.", + ) + return self.__getattribute__(item) - def __init__( - self, - name: str, - content: Any, - role: Literal["system", "user", "assistant"] = None, - url: Optional[Union[List[str], str]] = None, - timestamp: Optional[str] = None, - echo: bool = False, - metadata: Optional[Union[dict, str]] = None, - **kwargs: Any, - ) -> None: - """Initialize the message object + @property + def id(self) -> str: + """The identity of the message.""" + return self._id - Args: - name (`str`): - The name of who send the message. - content (`Any`): - The content of the message. - role (`Literal["system", "user", "assistant"]`): - Used to identify the source of the message, e.g. the system - information, the user input, or the model response. This - argument is used to accommodate most Chat API formats. - url (`Optional[Union[List[str], str]]`, defaults to `None`): - A url to file, image, video, audio or website. - timestamp (`Optional[str]`, defaults to `None`): - The timestamp of the message, if None, it will be set to - current time. - echo (`bool`, defaults to `False`): - Whether to print the message to the console. - metadata (`Optional[Union[dict, str]]`, defaults to `None`): - Save the information for application's control flow, or other - purposes. - **kwargs (`Any`): - Other attributes of the message. - """ + @property + def name(self) -> str: + """The name of the message sender.""" + return self._name - if role is None: + @property + def _colored_name(self) -> str: + """The name around with color marks, used to print in the terminal.""" + m1, m2 = _map_string_to_color_mark(self.name) + return f"{m1}{self.name}{m2}" + + @property + def content(self) -> Any: + """The content of the message.""" + return self._content + + @property + def role(self) -> Literal["system", "user", "assistant"]: + """The role of the message sender, chosen from 'system', 'user', + 'assistant'.""" + return self._role + + @property + def url(self) -> Optional[Union[str, List[str]]]: + """A URL string or a list of URL strings.""" + return self._url + + @property + def metadata(self) -> Optional[Union[dict, str]]: + """The metadata of the message, which can store some additional + information.""" + return self._metadata + + @property + def timestamp(self) -> str: + """The timestamp when the message is created.""" + return self._timestamp + + @id.setter # type: ignore[no-redef] + def id(self, value: str) -> None: + """Set the identity of the message.""" + self._id = value + + @name.setter # type: ignore[no-redef] + def name(self, value: str) -> None: + """Set the name of the message sender.""" + self._name = value + + @content.setter # type: ignore[no-redef] + def content(self, value: Any) -> None: + """Set the content of the message.""" + if not is_serializable(value): logger.warning( - "A new field `role` is newly added to the message. " - "Please specify the role of the message. Currently we use " - 'a default "assistant" value.', + f"The content of {type(value)} is not serializable, which " + f"may cause problems.", + ) + self._content = value + + @role.setter # type: ignore[no-redef] + def role(self, value: Literal["system", "user", "assistant"]) -> None: + """Set the role of the message sender. The role must be one of + 'system', 'user', 'assistant'.""" + if value not in ["system", "user", "assistant"]: + raise ValueError( + f"Invalid role {value}. The role must be one of " + f"['system', 'user', 'assistant']", ) + self._role = value - super().__init__( - name=name, - content=content, - role=role or "assistant", - url=url, - timestamp=timestamp, - metadata=metadata, - **kwargs, - ) + @url.setter # type: ignore[no-redef] + def url(self, value: Union[str, List[str], None]) -> None: + """Set the url of the message. The url can be a URL string or a list of + URL strings.""" + self._url = value - m1, m2 = _map_string_to_color_mark(self.name) - self._colored_name = f"{m1}{self.name}{m2}" + @metadata.setter # type: ignore[no-redef] + def metadata(self, value: Union[dict, str, None]) -> None: + """Set the metadata of the message to store some additional + information.""" + self._metadata = value - if echo: - logger.chat(self) + @timestamp.setter # type: ignore[no-redef] + def timestamp(self, value: str) -> None: + """Set the timestamp of the message.""" + self._timestamp = value def formatted_str(self, colored: bool = False) -> str: """Return the formatted string of the message. If the message has an @@ -171,6 +210,9 @@ def formatted_str(self, colored: bool = False) -> str: Args: colored (`bool`, defaults to `False`): Whether to color the name of the message + + Returns: + `str`: The formatted string of the message. """ if colored: name = self._colored_name @@ -186,5 +228,59 @@ def formatted_str(self, colored: bool = False) -> str: colored_strs.append(f"{name}: {self.url}") return "\n".join(colored_strs) - def serialize(self) -> str: - return json.dumps({"__type": "Msg", **self}) + def to_dict(self) -> dict: + """Serialize the message into a dictionary, which can be + deserialized by calling the `from_dict` function. + + Returns: + `dict`: The serialized dictionary. + """ + serialized_dict = { + "__module__": self.__class__.__module__, + "__name__": self.__class__.__name__, + } + + for attr_name in self.__serialized_attrs: + serialized_dict[attr_name] = getattr(self, f"_{attr_name}") + + return serialized_dict + + @classmethod + def from_dict(cls, serialized_dict: dict) -> "Msg": + """Deserialize the dictionary to a Msg object. + + Args: + serialized_dict (`dict`): + A dictionary that must contain the keys in + `Msg.__serialized_attrs`, and the keys `__module__` and + `__name__`. + + Returns: + `Msg`: A Msg object. + """ + assert set( + serialized_dict.keys(), + ) == cls.__serialized_attrs.union( + { + "__module__", + "__name__", + }, + ), ( + f"Expect keys {cls.__serialized_attrs}, but get " + f"{set(serialized_dict.keys())}", + ) + + assert serialized_dict.pop("__module__") == cls.__module__ + assert serialized_dict.pop("__name__") == cls.__name__ + + obj = cls( + name=serialized_dict["name"], + content=serialized_dict["content"], + role=serialized_dict["role"], + url=serialized_dict["url"], + metadata=serialized_dict["metadata"], + echo=False, + ) + obj.id = serialized_dict["id"] + obj.timestamp = serialized_dict["timestamp"] + return obj diff --git a/src/agentscope/message/placeholder.py b/src/agentscope/message/placeholder.py index 8420e74b8..73da3d231 100644 --- a/src/agentscope/message/placeholder.py +++ b/src/agentscope/message/placeholder.py @@ -1,19 +1,21 @@ # -*- coding: utf-8 -*- +# mypy: disable-error-code="misc" """The placeholder message for RpcAgent.""" -import json -from typing import Any, Optional, List, Union, Sequence +import os +from typing import Any, Optional, List, Union, Sequence, Literal from loguru import logger -from .msg import Msg, MessageBase +from .msg import Msg from ..rpc import RpcAgentClient, ResponseStub, call_in_thread -from ..utils.tools import is_web_accessible +from ..serialize import deserialize, is_serializable, serialize +from ..utils.tools import _is_web_url class PlaceholderMessage(Msg): """A placeholder for the return message of RpcAgent.""" - PLACEHOLDER_ATTRS = { + __placeholder_attrs = { "_host", "_port", "_client", @@ -22,44 +24,26 @@ class PlaceholderMessage(Msg): "_is_placeholder", } - LOCAL_ATTRS = { - "name", - "timestamp", - *PLACEHOLDER_ATTRS, + __serialized_attrs = { + "_host", + "_port", + "_task_id", } + _is_placeholder: bool + """Indicates whether the real message is still in the rpc server.""" + def __init__( self, - name: str, - content: Any, - url: Optional[Union[List[str], str]] = None, - timestamp: Optional[str] = None, host: str = None, port: int = None, task_id: int = None, client: Optional[RpcAgentClient] = None, - x: dict = None, - **kwargs: Any, + x: Optional[Union[Msg, Sequence[Msg]]] = None, ) -> None: """A placeholder message, records the address of the real message. Args: - name (`str`): - The name of who send the message. It's often used in - role-playing scenario to tell the name of the sender. - However, you can also only use `role` when calling openai api. - The usage of `name` refers to - https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models. - content (`Any`): - The content of the message. - role (`Literal["system", "user", "assistant"]`, defaults to "assistant"): - The role of the message, which can be one of the `"system"`, - `"user"`, or `"assistant"`. - url (`Optional[Union[List[str], str]]`, defaults to None): - A url to file, image, video, audio or website. - timestamp (`Optional[str]`, defaults to None): - The timestamp of the message, if None, it will be set to - current time. host (`str`, defaults to `None`): The hostname of the rpc server where the real message is located. @@ -70,15 +54,15 @@ def __init__( client (`RpcAgentClient`, defaults to `None`): An RpcAgentClient instance used to connect to the generator of this placeholder. - x (`dict`, defaults to `None`): + x (`Optional[Msg, Sequence[Msg]]`, defaults to `None`): Input parameters used to call rpc methods on the client. - """ # noqa + """ super().__init__( - name=name, - content=content, - url=url, - timestamp=timestamp, - **kwargs, + name="", + content="", + role="assistant", + url=None, + metadata=None, ) # placeholder indicates whether the real message is still in rpc server self._is_placeholder = True @@ -90,134 +74,232 @@ def __init__( else: self._stub = call_in_thread( client, - x.serialize() if x is not None else "", + serialize(x), "_reply", ) self._host = client.host self._port = client.port self._task_id = None - def __is_local(self, key: Any) -> bool: - return ( - key in PlaceholderMessage.LOCAL_ATTRS or not self._is_placeholder - ) + @property + def id(self) -> str: + """The identity of the message.""" + if self._is_placeholder: + self.update_value() + return self._id - def __getattr__(self, __name: str) -> Any: - """Get attribute value from PlaceholderMessage. Get value from rpc - agent server if necessary. + @property + def name(self) -> str: + """The name of the message sender.""" + if self._is_placeholder: + self.update_value() + return self._name - Args: - __name (`str`): - Attribute name. - """ - if not self.__is_local(__name): + @property + def content(self) -> Any: + """The content of the message.""" + if self._is_placeholder: + self.update_value() + return self._content + + @property + def role(self) -> Literal["system", "user", "assistant"]: + """The role of the message sender, chosen from 'system', 'user', + 'assistant'.""" + if self._is_placeholder: self.update_value() - return MessageBase.__getattr__(self, __name) + return self._role - def __getitem__(self, __key: Any) -> Any: - """Get item value from PlaceholderMessage. Get value from rpc - agent server if necessary. + @property + def url(self) -> Optional[Union[str, List[str]]]: + """A URL string or a list of URL strings.""" + if self._is_placeholder: + self.update_value() + return self._url - Args: - __key (`Any`): - Item name. - """ - if not self.__is_local(__key): + @property + def metadata(self) -> Optional[Union[dict, str]]: + """The metadata of the message, which can store some additional + information.""" + if self._is_placeholder: self.update_value() - return MessageBase.__getitem__(self, __key) + return self._metadata + + @property + def timestamp(self) -> str: + """The timestamp when the message is created.""" + if self._is_placeholder: + self.update_value() + return self._timestamp + + @id.setter # type: ignore[no-redef] + def id(self, value: str) -> None: + """Set the identity of the message.""" + self._id = value + + @name.setter # type: ignore[no-redef] + def name(self, value: str) -> None: + """Set the name of the message sender.""" + self._name = value + + @content.setter # type: ignore[no-redef] + def content(self, value: Any) -> None: + """Set the content of the message.""" + if not is_serializable(value): + logger.warning( + f"The content of {type(value)} is not serializable, which " + f"may cause problems.", + ) + self._content = value + + @role.setter # type: ignore[no-redef] + def role(self, value: Literal["system", "user", "assistant"]) -> None: + """Set the role of the message sender. The role must be one of + 'system', 'user', 'assistant'.""" + if value not in ["system", "user", "assistant"]: + raise ValueError( + f"Invalid role {value}. The role must be one of " + f"['system', 'user', 'assistant']", + ) + self._role = value - def update_value(self) -> MessageBase: + @url.setter # type: ignore[no-redef] + def url(self, value: Union[str, List[str], None]) -> None: + """Set the url of the message. The url can be a URL string or a list of + URL strings.""" + self._url = value + + @metadata.setter # type: ignore[no-redef] + def metadata(self, value: Union[dict, str, None]) -> None: + """Set the metadata of the message to store some additional + information.""" + self._metadata = value + + @timestamp.setter # type: ignore[no-redef] + def timestamp(self, value: str) -> None: + """Set the timestamp of the message.""" + self._timestamp = value + + def update_value(self) -> None: """Get attribute values from rpc agent server immediately""" if self._is_placeholder: # retrieve real message from rpc agent server self.__update_task_id() client = RpcAgentClient(self._host, self._port) result = client.update_placeholder(task_id=self._task_id) - msg = deserialize(result) - self.__update_url(msg) # type: ignore[arg-type] - self.update(msg) - # the actual value has been updated, not a placeholder anymore + + # Update the values according to the result obtained from the + # distributed agent + data = deserialize(result) + + self.id = data.id + self.name = data.name + self.role = data.role + self.content = data.content + self.metadata = data.metadata + + self.timestamp = data.timestamp + + # For url field, download the file if it's a local file of the + # distributed agent, and turn it into a local url + self.url = self.__update_url(data.url) + self._is_placeholder = False - return self - def __update_url(self, msg: MessageBase) -> None: - """Update the url field of the message.""" - if hasattr(msg, "url") and msg.url is None: - return - url = msg.url + def __update_url( + self, + url: Union[list[str], str, None], + ) -> Union[list, str, None]: + """If the url links to + - a file that the main process can access, return the url directly + - a web resource, return the url directly + - a local file of the distributed agent (maybe in the deployed + machine of the distributed agent), we download the file and update + the url to the local url. + - others (maybe a meaningless url, e.g "xxx.com"), return the url. + + Args: + url (`Union[List[str], str, None]`): + The url to be updated. + """ + + if url is None: + return None + if isinstance(url, str): - urls = [url] - else: - urls = url - checked_urls = [] - for url in urls: - if not is_web_accessible(url): - client = RpcAgentClient(self._host, self._port) - checked_urls.append(client.download_file(path=url)) - else: - checked_urls.append(url) - msg.url = checked_urls[0] if isinstance(url, str) else checked_urls + if os.path.exists(url) or _is_web_url(url): + return url + + # Try to get the file from the distributed agent + client = RpcAgentClient(self.host, self.port) + # TODO: what if failed here? + local_url = client.download_file(path=url) + + return local_url + + if isinstance(url, list): + return [self.__update_url(u) for u in url] + + raise TypeError( + f"Invalid URL type, expect str, list[str] or None, " + f"got {type(url)}.", + ) def __update_task_id(self) -> None: + """Get the task_id from the rpc server.""" if self._stub is not None: try: - resp = deserialize(self._stub.get_response()) + task_id = deserialize(self._stub.get_response()) except Exception as e: - logger.error( - f"Failed to get task_id: {self._stub.get_response()}", - ) raise ValueError( f"Failed to get task_id: {self._stub.get_response()}", ) from e - self._task_id = resp["task_id"] # type: ignore[call-overload] + self._task_id = task_id self._stub = None - def serialize(self) -> str: + def to_dict(self) -> dict: + """Serialize the placeholder message.""" if self._is_placeholder: self.__update_task_id() - return json.dumps( - { - "__type": "PlaceholderMessage", - "name": self.name, - "content": None, - "timestamp": self.timestamp, - "host": self._host, - "port": self._port, - "task_id": self._task_id, - }, - ) - else: - states = { - k: v - for k, v in self.items() - if k not in PlaceholderMessage.PLACEHOLDER_ATTRS - } - states["__type"] = "Msg" - return json.dumps(states) + # Serialize the placeholder message + serialized_dict = { + "__module__": self.__class__.__module__, + "__name__": self.__class__.__name__, + } -_MSGS = { - "Msg": Msg, - "PlaceholderMessage": PlaceholderMessage, -} + for attr_name in self.__serialized_attrs: + serialized_dict[attr_name] = getattr(self, attr_name) + return serialized_dict -def deserialize(s: Union[str, bytes]) -> Union[Msg, Sequence]: - """Deserialize json string into MessageBase""" - js_msg = json.loads(s) - msg_type = js_msg.pop("__type") - if msg_type == "List": - return [deserialize(s) for s in js_msg["__value"]] - elif msg_type not in _MSGS: - raise NotImplementedError( - f"Deserialization of {msg_type} is not supported.", - ) - return _MSGS[msg_type](**js_msg) + else: + # Serialize into a normal Msg object + serialized_dict = { + "__module__": Msg.__module__, + "__name__": Msg.__name__, + } + # TODO: We will merge the placeholder and message classes in the + # future to avoid the hard coding of the serialized attributes + # here + for attr_name in [ + "id", + "name", + "content", + "role", + "url", + "metadata", + "timestamp", + ]: + serialized_dict[attr_name] = getattr(self, attr_name) + return serialized_dict -def serialize(messages: Union[Sequence[MessageBase], MessageBase]) -> str: - """Serialize multiple MessageBase instance""" - if isinstance(messages, MessageBase): - return messages.serialize() - seq = [msg.serialize() for msg in messages] - return json.dumps({"__type": "List", "__value": seq}) + @classmethod + def from_dict(cls, serialized_dict: dict) -> "PlaceholderMessage": + """Create a PlaceholderMessage from a dictionary.""" + return cls( + host=serialized_dict["_host"], + port=serialized_dict["_port"], + task_id=serialized_dict["_task_id"], + ) diff --git a/src/agentscope/rpc/rpc_agent_client.py b/src/agentscope/rpc/rpc_agent_client.py index 878ba1613..4e4bdbe45 100644 --- a/src/agentscope/rpc/rpc_agent_client.py +++ b/src/agentscope/rpc/rpc_agent_client.py @@ -7,6 +7,9 @@ from typing import Optional, Sequence, Union, Generator from loguru import logger +from ..message import Msg +from ..serialize import deserialize + try: import dill import grpc @@ -304,7 +307,7 @@ def set_model_configs( return False return True - def get_agent_memory(self, agent_id: str) -> Union[list, dict]: + def get_agent_memory(self, agent_id: str) -> Union[list[Msg], Msg]: """Get the memory usage of the specific agent.""" with grpc.insecure_channel(f"{self.host}:{self.port}") as channel: stub = RpcAgentStub(channel) @@ -313,7 +316,7 @@ def get_agent_memory(self, agent_id: str) -> Union[list, dict]: ) if not resp.ok: logger.error(f"Error in get_agent_memory: {resp.message}") - return json.loads(resp.message) + return deserialize(resp.message) def download_file(self, path: str) -> str: """Download a file from a remote server to the local machine. diff --git a/src/agentscope/serialize.py b/src/agentscope/serialize.py new file mode 100644 index 000000000..bef8dd8f5 --- /dev/null +++ b/src/agentscope/serialize.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +"""The serialization module for the package.""" +import importlib +import json +from typing import Any + + +def _default_serialize(obj: Any) -> Any: + """Serialize the object when `json.dumps` cannot handle it.""" + if hasattr(obj, "__module__") and hasattr(obj, "__class__"): + # To avoid circular import, we hard code the module name here + if ( + obj.__module__ == "agentscope.message.msg" + and obj.__class__.__name__ == "Msg" + ): + return obj.to_dict() + + if ( + obj.__module__ == "agentscope.message.placeholder" + and obj.__class__.__name__ == "PlaceholderMessage" + ): + return obj.to_dict() + + return obj + + +def _deserialize_hook(data: dict) -> Any: + """Deserialize the JSON string to an object, including Msg object in + AgentScope.""" + module_name = data.get("__module__", None) + class_name = data.get("__name__", None) + + if module_name is not None and class_name is not None: + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + if hasattr(cls, "from_dict"): + return cls.from_dict(data) + return data + + +def serialize(obj: Any) -> str: + """Serialize the object to a JSON string. + + For AgentScope, this function supports to serialize `Msg` object for now. + """ + # TODO: We leave the serialization of agents in next PR + return json.dumps(obj, ensure_ascii=False, default=_default_serialize) + + +def deserialize(s: str) -> Any: + """Deserialize the JSON string to an object + + For AgentScope, this function supports to serialize `Msg` object for now. + """ + # TODO: We leave the serialization of agents in next PR + return json.loads(s, object_hook=_deserialize_hook) + + +def is_serializable(obj: Any) -> bool: + """Check if the object is serializable in the scope of AgentScope.""" + try: + serialize(obj) + return True + except Exception: + return False diff --git a/src/agentscope/server/servicer.py b/src/agentscope/server/servicer.py index e047b8fc4..1404d9adc 100644 --- a/src/agentscope/server/servicer.py +++ b/src/agentscope/server/servicer.py @@ -31,17 +31,14 @@ ExpiringDict = ImportErrorReporter(import_error, "distribute") import agentscope.rpc.rpc_agent_pb2 as agent_pb2 +from agentscope.serialize import deserialize, serialize from agentscope.agents.agent import AgentBase from agentscope.manager import ModelManager from agentscope.manager import ASManager from agentscope.studio._client import _studio_client from agentscope.exception import StudioRegisterError from agentscope.rpc.rpc_agent_pb2_grpc import RpcAgentServicer -from agentscope.message import ( - Msg, - PlaceholderMessage, - deserialize, -) +from agentscope.message import Msg, PlaceholderMessage def _register_server_to_studio( @@ -347,7 +344,7 @@ def update_placeholder( else: return agent_pb2.GeneralResponse( ok=True, - message=result.serialize(), + message=serialize(result), ) def get_agent_list( @@ -362,7 +359,8 @@ def get_agent_list( summaries.append(str(agent)) return agent_pb2.GeneralResponse( ok=True, - message=json.dumps(summaries), + # TODO: unified into serialize function to avoid error. + message=serialize(summaries), ) def get_server_info( @@ -378,7 +376,7 @@ def get_server_info( status["cpu"] = process.cpu_percent(interval=1) status["mem"] = process.memory_info().rss / (1024**2) status["size"] = len(self.agent_pool) - return agent_pb2.GeneralResponse(ok=True, message=json.dumps(status)) + return agent_pb2.GeneralResponse(ok=True, message=serialize(status)) def set_model_configs( self, @@ -416,7 +414,7 @@ def get_agent_memory( ) return agent_pb2.GeneralResponse( ok=True, - message=json.dumps(agent.memory.get_memory()), + message=serialize(agent.memory.get_memory()), ) def download_file( @@ -465,11 +463,7 @@ def _reply(self, request: agent_pb2.RpcMsg) -> agent_pb2.GeneralResponse: ) return agent_pb2.GeneralResponse( ok=True, - message=Msg( # type: ignore[arg-type] - name=self.get_agent(request.agent_id).name, - content=None, - task_id=task_id, - ).serialize(), + message=str(task_id), ) def _observe(self, request: agent_pb2.RpcMsg) -> agent_pb2.GeneralResponse: @@ -483,9 +477,13 @@ def _observe(self, request: agent_pb2.RpcMsg) -> agent_pb2.GeneralResponse: `RpcMsg`: Empty RpcMsg. """ msgs = deserialize(request.value) - for msg in msgs: - if isinstance(msg, PlaceholderMessage): - msg.update_value() + if isinstance(msgs, list): + for msg in msgs: + if isinstance(msg, PlaceholderMessage): + msg.update_value() + elif isinstance(msgs, PlaceholderMessage): + msgs.update_value() + self.agent_pool[request.agent_id].observe(msgs) return agent_pb2.GeneralResponse(ok=True) @@ -493,14 +491,14 @@ def _process_messages( self, task_id: int, agent_id: str, - task_msg: dict = None, + task_msg: Msg = None, ) -> None: """Processing an input message and generate its reply message. Args: - task_id (`int`): task id of the input message, . + task_id (`int`): task id of the input message. agent_id (`str`): the id of the agent that accepted the message. - task_msg (`dict`): the input message. + task_msg (`Msg`): the input message. """ if isinstance(task_msg, PlaceholderMessage): task_msg.update_value() diff --git a/src/agentscope/service/multi_modality/openai_services.py b/src/agentscope/service/multi_modality/openai_services.py index 7e2acba91..16aca5a58 100644 --- a/src/agentscope/service/multi_modality/openai_services.py +++ b/src/agentscope/service/multi_modality/openai_services.py @@ -27,7 +27,7 @@ from agentscope.utils.tools import _download_file -from agentscope.message import MessageBase +from agentscope.message import Msg def _url_to_filename(url: str) -> str: @@ -420,7 +420,7 @@ def openai_image_to_text( model_name=model, api_key=api_key, ) - messages = MessageBase( + messages = Msg( name="service_call", role="user", content=prompt, diff --git a/src/agentscope/studio/static/html-drag-components/message-msg.html b/src/agentscope/studio/static/html-drag-components/message-msg.html index ca29eef48..9c7a10d55 100644 --- a/src/agentscope/studio/static/html-drag-components/message-msg.html +++ b/src/agentscope/studio/static/html-drag-components/message-msg.html @@ -16,6 +16,11 @@ data-required="true">
+ + +
+ diff --git a/src/agentscope/studio/static/js/workstation.js b/src/agentscope/studio/static/js/workstation.js index fceaac016..3323b55b1 100644 --- a/src/agentscope/studio/static/js/workstation.js +++ b/src/agentscope/studio/static/js/workstation.js @@ -569,6 +569,7 @@ async function addNodeToDrawFlow(name, pos_x, pos_y) { "args": { "name": '', + "role": '', "content": '', "url": '' } @@ -1326,6 +1327,21 @@ function checkConditions() { isApiKeyEmpty = isApiKeyEmpty || true; } } + + if (node.name === "Message") { + const validRoles = ["system", "assistant", "user"]; + if (!validRoles.includes(node.data.args.role)) { + Swal.fire({ + title: 'Invalid Role for Message', + html: + `Invalid role ${node.data.args.role}.
The role must be in ['system', 'user', 'assistant']`, + icon: 'error', + confirmButtonText: 'Ok' + }); + return false; + } + } + if (node.name.includes('Agent') && "model_config_name" in node.data.args) { hasAgentError = false; if (node.data && node.data.args) { diff --git a/src/agentscope/studio/static/workstation_templates/en4.json b/src/agentscope/studio/static/workstation_templates/en4.json index ddb39b327..0fcb35a2d 100644 --- a/src/agentscope/studio/static/workstation_templates/en4.json +++ b/src/agentscope/studio/static/workstation_templates/en4.json @@ -213,6 +213,7 @@ "data": { "args": { "name": "User", + "role": "user", "content": "Hello every one", "url": "" } diff --git a/src/agentscope/test.py b/src/agentscope/test.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/agentscope/utils/tools.py b/src/agentscope/utils/tools.py index 4e2382fc0..ab060e44d 100644 --- a/src/agentscope/utils/tools.py +++ b/src/agentscope/utils/tools.py @@ -255,7 +255,7 @@ def generate_id_from_seed(seed: str, length: int = 8) -> str: return "".join(id_chars) -def is_web_accessible(url: str) -> bool: +def _is_web_url(url: str) -> bool: """Whether the url is accessible from the Web. Args: @@ -466,7 +466,7 @@ def _map_string_to_color_mark( ("\033[97m", "\033[0m"), ] - hash_value = hash(target_str) + hash_value = int(hashlib.sha256(target_str.encode()).hexdigest(), 16) index = hash_value % len(color_marks) return color_marks[index] diff --git a/tests/agent_test.py b/tests/agent_test.py index 0d3ff1d91..629e69d7c 100644 --- a/tests/agent_test.py +++ b/tests/agent_test.py @@ -26,9 +26,6 @@ def __init__( use_memory=( kwargs["use_memory"] if "use_memory" in kwargs else None ), - memory_config=( - kwargs["memory_config"] if "memory_config" in kwargs else None - ), ) diff --git a/tests/logger_test.py b/tests/logger_test.py index 1cc684b89..762b0d697 100644 --- a/tests/logger_test.py +++ b/tests/logger_test.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """ Unit test for logger chat""" +import json import os import shutil import time @@ -29,13 +30,11 @@ def test_logger_chat(self) -> None: msg1 = Msg("abc", "def", "assistant") msg1.id = 1 msg1.timestamp = 1 - msg1._colored_name = "1" # pylint: disable=protected-access # url msg2 = Msg("abc", "def", "assistant", url="https://xxx.png") msg2.id = 2 msg2.timestamp = 2 - msg2._colored_name = "2" # pylint: disable=protected-access # urls msg3 = Msg( @@ -46,13 +45,11 @@ def test_logger_chat(self) -> None: ) msg3.id = 3 msg3.timestamp = 3 - msg3._colored_name = "3" # pylint: disable=protected-access # html labels msg4 = Msg("Bob", "abc None: ) as file: lines = file.readlines() - ground_truth = [ - '{"id": 1, "timestamp": 1, "name": "abc", "content": "def", ' - '"role": "assistant", "url": null, "metadata": null, ' - '"_colored_name": "1"}\n', - '{"id": 2, "timestamp": 2, "name": "abc", "content": "def", ' - '"role": "assistant", "url": "https://xxx.png", "metadata": null, ' - '"_colored_name": "2"}\n', - '{"id": 3, "timestamp": 3, "name": "abc", "content": "def", ' - '"role": "assistant", "url": ' - '["https://yyy.png", "https://xxx.png"], "metadata": null, ' - '"_colored_name": "3"}\n', - '{"id": 4, "timestamp": 4, "name": "Bob", "content": ' - '"abcabc None: """Tear down for LoggerTest.""" diff --git a/tests/memory_test.py b/tests/memory_test.py index 55e02c109..8a3fdbfd0 100644 --- a/tests/memory_test.py +++ b/tests/memory_test.py @@ -9,6 +9,7 @@ from agentscope.message import Msg from agentscope.memory import TemporaryMemory +from agentscope.serialize import serialize class TemporaryMemoryTest(unittest.TestCase): @@ -80,7 +81,8 @@ def test_invalid(self) -> None: with self.assertRaises(Exception) as context: self.memory.add(self.invalid) self.assertTrue( - f"Cannot add {self.invalid} to memory" in str(context.exception), + f"Cannot add {type(self.invalid)} to memory, must be a Msg object." + in str(context.exception), ) def test_load_export(self) -> None: @@ -88,10 +90,11 @@ def test_load_export(self) -> None: Test load and export function of TemporaryMemory """ memory = TemporaryMemory() - user_input = Msg(name="user", content="Hello") + user_input = Msg(name="user", content="Hello", role="user") agent_input = Msg( name="agent", content="Hello! How can I help you?", + role="assistant", ) memory.load([user_input, agent_input]) retrieved_mem = memory.export(to_mem=True) @@ -108,8 +111,8 @@ def test_load_export(self) -> None: ) memory.load(self.file_name_1) self.assertEqual( - memory.get_memory(), - [user_input, agent_input], + serialize(memory.get_memory()), + serialize([user_input, agent_input]), ) diff --git a/tests/message_test.py b/tests/message_test.py new file mode 100644 index 000000000..7612842e6 --- /dev/null +++ b/tests/message_test.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +"""The unit test for message module.""" + +import unittest + +from agentscope.message import Msg + + +class MessageTest(unittest.TestCase): + """The test cases for message module.""" + + def test_msg(self) -> None: + """Test the basic attributes in Msg object.""" + msg = Msg(name="A", content="B", role="assistant") + self.assertEqual(msg.name, "A") + self.assertEqual(msg.content, "B") + self.assertEqual(msg.role, "assistant") + self.assertEqual(msg.metadata, None) + self.assertEqual(msg.url, None) + + def test_formatted_msg(self) -> None: + """Test the formatted message.""" + msg = Msg(name="A", content="B", role="assistant") + self.assertEqual( + msg.formatted_str(), + "A: B", + ) + self.assertEqual( + msg.formatted_str(colored=True), + "\x1b[95mA\x1b[0m: B", + ) + + def test_serialize(self) -> None: + """Test the serialization and deserialization of Msg object.""" + msg = Msg(name="A", content="B", role="assistant") + serialized_msg = msg.to_dict() + deserialized_msg = Msg.from_dict(serialized_msg) + self.assertEqual(msg.id, deserialized_msg.id) + self.assertEqual(msg.name, deserialized_msg.name) + self.assertEqual(msg.content, deserialized_msg.content) + self.assertEqual(msg.role, deserialized_msg.role) + self.assertEqual(msg.metadata, deserialized_msg.metadata) + self.assertEqual(msg.url, deserialized_msg.url) + self.assertEqual(msg.timestamp, deserialized_msg.timestamp) diff --git a/tests/msghub_test.py b/tests/msghub_test.py index 9859c364e..b5adadb25 100644 --- a/tests/msghub_test.py +++ b/tests/msghub_test.py @@ -34,10 +34,10 @@ def setUp(self) -> None: def test_msghub_operation(self) -> None: """Test add, delete and broadcast operations""" - msg1 = Msg(name="a1", content="msg1") - msg2 = Msg(name="a2", content="msg2") - msg3 = Msg(name="a3", content="msg3") - msg4 = Msg(name="a4", content="msg4") + msg1 = Msg(name="a1", content="msg1", role="assistant") + msg2 = Msg(name="a2", content="msg2", role="assistant") + msg3 = Msg(name="a3", content="msg3", role="assistant") + msg4 = Msg(name="a4", content="msg4", role="assistant") with msghub(participants=[self.agent1, self.agent2]) as hub: self.agent1(msg1) @@ -73,7 +73,7 @@ def test_msghub(self) -> None: name="w1", content="This secret that my password is 123456 can't be" " leaked!", - role="wisper", + role="assistant", ), ] diff --git a/tests/retrieval_from_list_test.py b/tests/retrieval_from_list_test.py index 52b30720b..f42529e3d 100644 --- a/tests/retrieval_from_list_test.py +++ b/tests/retrieval_from_list_test.py @@ -6,7 +6,7 @@ from agentscope.service import retrieve_from_list, cos_sim from agentscope.service.service_status import ServiceExecStatus -from agentscope.message import MessageBase, Msg +from agentscope.message import Msg from agentscope.memory.temporary_memory import TemporaryMemory from agentscope.models import OpenAIEmbeddingWrapper, ModelResponse @@ -40,11 +40,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> ModelResponse: m2 = Msg(name="env", content="test2", role="assistant") m2.embedding = [0.5, 0.5] m2.timestamp = "2023-12-18 21:50:59" - memory = TemporaryMemory(config={}, embedding_model=dummy_model) + memory = TemporaryMemory(embedding_model=dummy_model) memory.add(m1) memory.add(m2) - def score_func(m1: MessageBase, m2: MessageBase) -> float: + def score_func(m1: Msg, m2: Msg) -> float: relevance = cos_sim(m1.embedding, m2.embedding).content time_gap = ( datetime.strptime(m1.timestamp, "%Y-%m-%d %H:%M:%S") diff --git a/tests/rpc_agent_test.py b/tests/rpc_agent_test.py index 0c62f9718..ab3673124 100644 --- a/tests/rpc_agent_test.py +++ b/tests/rpc_agent_test.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +# pylint: disable=W0212 """ Unit tests for rpc agent classes """ @@ -14,10 +15,10 @@ import agentscope from agentscope.agents import AgentBase, DistConf, DialogAgent from agentscope.manager import MonitorManager, ASManager +from agentscope.serialize import deserialize, serialize from agentscope.server import RpcAgentServerLauncher from agentscope.message import Msg from agentscope.message import PlaceholderMessage -from agentscope.message import deserialize from agentscope.msghub import msghub from agentscope.pipelines import sequentialpipeline from agentscope.rpc.rpc_agent_client import RpcAgentClient @@ -202,35 +203,34 @@ def test_single_rpc_agent_server(self) -> None: role="system", ) result = agent_a(msg) - # get name without waiting for the server - self.assertEqual(result.name, "a") - self.assertEqual(result["name"], "a") - js_placeholder_result = result.serialize() - self.assertTrue(result._is_placeholder) # pylint: disable=W0212 + + # The deserialization without accessing the attributes will generate + # a PlaceholderMessage instance. + js_placeholder_result = serialize(result) placeholder_result = deserialize(js_placeholder_result) self.assertTrue(isinstance(placeholder_result, PlaceholderMessage)) - self.assertEqual(placeholder_result.name, "a") - self.assertEqual( - placeholder_result["name"], # type: ignore[call-overload] - "a", - ) - self.assertTrue( - placeholder_result._is_placeholder, # pylint: disable=W0212 - ) + + # Fetch the attribute from distributed agent + self.assertTrue(result._is_placeholder) + self.assertEqual(result.name, "System") + self.assertFalse(result._is_placeholder) + # wait to get content self.assertEqual(result.content, msg.content) - self.assertFalse(result._is_placeholder) # pylint: disable=W0212 self.assertEqual(result.id, 0) + + # The second time to fetch the attributes from the distributed agent self.assertTrue( - placeholder_result._is_placeholder, # pylint: disable=W0212 + placeholder_result._is_placeholder, ) self.assertEqual(placeholder_result.content, msg.content) self.assertFalse( - placeholder_result._is_placeholder, # pylint: disable=W0212 + placeholder_result._is_placeholder, ) self.assertEqual(placeholder_result.id, 0) + # check msg - js_msg_result = result.serialize() + js_msg_result = serialize(result) msg_result = deserialize(js_msg_result) self.assertTrue(isinstance(msg_result, Msg)) self.assertEqual(msg_result.content, msg.content) @@ -250,7 +250,7 @@ def test_connect_to_an_existing_rpc_server(self) -> None: ) launcher.launch() client = RpcAgentClient(host=launcher.host, port=launcher.port) - self.assertTrue(client.is_alive()) # pylint: disable=W0212 + self.assertTrue(client.is_alive()) agent_a = DemoRpcAgent( name="a", ).to_dist( @@ -264,7 +264,7 @@ def test_connect_to_an_existing_rpc_server(self) -> None: ) result = agent_a(msg) # get name without waiting for the server - self.assertEqual(result.name, "a") + self.assertEqual(result.name, "System") # waiting for server self.assertEqual(result.content, msg.content) # test dict usage @@ -275,9 +275,9 @@ def test_connect_to_an_existing_rpc_server(self) -> None: ) result = agent_a(msg) # get name without waiting for the server - self.assertEqual(result["name"], "a") + self.assertEqual(result.name, "System") # waiting for server - self.assertEqual(result["content"], msg.content) + self.assertEqual(result.content, msg.content) # test to_str msg = Msg( name="System", @@ -285,7 +285,7 @@ def test_connect_to_an_existing_rpc_server(self) -> None: role="system", ) result = agent_a(msg) - self.assertEqual(result.formatted_str(), "a: {'text': 'test'}") + self.assertEqual(result.formatted_str(), "System: {'text': 'test'}") launcher.shutdown() def test_multi_rpc_agent(self) -> None: @@ -436,7 +436,7 @@ def test_multi_agent_in_same_server(self) -> None: host="127.0.0.1", port=launcher.port, ) - agent3._agent_id = agent1.agent_id # pylint: disable=W0212 + agent3._agent_id = agent1.agent_id agent3.client.agent_id = agent1.client.agent_id msg1 = Msg( name="System", @@ -474,7 +474,7 @@ def test_multi_agent_in_same_server(self) -> None: role="system", ) res2 = agent2(msg2) - self.assertRaises(ValueError, res2.__getattr__, "content") + self.assertRaises(ValueError, res2.update_value) # should override remote default parameter(e.g. name field) agent4 = DemoRpcAgentWithMemory( @@ -557,7 +557,7 @@ def test_error_handling(self) -> None: """Test error handling""" agent = DemoErrorAgent(name="a").to_dist() x = agent() - self.assertRaises(AgentCallError, x.__getattr__, "content") + self.assertRaises(AgentCallError, x.update_value) def test_agent_nesting(self) -> None: """Test agent nesting""" @@ -642,8 +642,8 @@ def test_agent_server_management_funcs(self) -> None: resp.update_value() memory = client.get_agent_memory(memory_agent.agent_id) self.assertEqual(len(memory), 2) - self.assertEqual(memory[0]["content"], "first msg") - self.assertEqual(memory[1]["content"]["mem_size"], 1) + self.assertEqual(memory[0].content, "first msg") + self.assertEqual(memory[1].content["mem_size"], 1) agent_lists = client.get_agent_list() self.assertEqual(len(agent_lists), 1) self.assertEqual(agent_lists[0]["agent_id"], memory_agent.agent_id) @@ -669,7 +669,7 @@ def test_agent_server_management_funcs(self) -> None: ), ) local_file_path = file.url - self.assertNotEqual(remote_file_path, local_file_path) + self.assertEqual(remote_file_path, local_file_path) with open(remote_file_path, "rb") as rf: remote_content = rf.read() with open(local_file_path, "rb") as lf: diff --git a/tests/serialize_test.py b/tests/serialize_test.py new file mode 100644 index 000000000..819bda14b --- /dev/null +++ b/tests/serialize_test.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# pylint: disable=protected-access +"""Unit test for serialization.""" +import json +import unittest + +from agentscope.message import Msg, PlaceholderMessage +from agentscope.serialize import serialize, deserialize + + +class SerializationTest(unittest.TestCase): + """The test cases for serialization.""" + + def test_serialize(self) -> None: + """Test the serialization function.""" + + msg1 = Msg("A", "A", "assistant") + msg2 = Msg("B", "B", "assistant") + placeholder = PlaceholderMessage( + host="localhost", + port=50051, + ) + + serialized_msg1 = serialize(msg1) + deserialized_msg1 = deserialize(serialized_msg1) + self.assertTrue(isinstance(serialized_msg1, str)) + self.assertTrue(isinstance(deserialized_msg1, Msg)) + + msg1_dict = json.loads(serialized_msg1) + self.assertDictEqual( + msg1_dict, + { + "id": msg1.id, + "name": msg1.name, + "content": msg1.content, + "role": msg1.role, + "timestamp": msg1.timestamp, + "metadata": msg1.metadata, + "url": msg1.url, + "__module__": "agentscope.message.msg", + "__name__": "Msg", + }, + ) + + serialized_list = serialize([msg1, msg2]) + deserialized_list = deserialize(serialized_list) + self.assertTrue(isinstance(serialized_list, str)) + self.assertTrue( + isinstance(deserialized_list, list) + and len(deserialized_list) == 2 + and all(isinstance(msg, Msg) for msg in deserialized_list), + ) + + dict_list = json.loads(serialized_list) + self.assertListEqual( + dict_list, + [ + { + "id": msg1.id, + "name": msg1.name, + "content": msg1.content, + "role": msg1.role, + "timestamp": msg1.timestamp, + "metadata": msg1.metadata, + "url": msg1.url, + "__module__": "agentscope.message.msg", + "__name__": "Msg", + }, + { + "id": msg2.id, + "name": msg2.name, + "content": msg2.content, + "role": msg2.role, + "timestamp": msg2.timestamp, + "metadata": msg2.metadata, + "url": msg2.url, + "__module__": "agentscope.message.msg", + "__name__": "Msg", + }, + ], + ) + + serialized_placeholder = serialize(placeholder) + deserialized_placeholder = deserialize(serialized_placeholder) + self.assertTrue(isinstance(serialized_placeholder, str)) + self.assertTrue( + isinstance(deserialized_placeholder, PlaceholderMessage), + ) + + placeholder_dict = json.loads(serialized_placeholder) + self.assertDictEqual( + placeholder_dict, + { + "_host": placeholder._host, + "_port": placeholder._port, + "_task_id": placeholder._task_id, + "__module__": "agentscope.message.placeholder", + "__name__": "PlaceholderMessage", + }, + )