From 86b665b502d18d65edbeca3844bcd3e1612c0841 Mon Sep 17 00:00:00 2001 From: zhiweiliu Date: Thu, 9 May 2024 11:56:29 -0700 Subject: [PATCH] updating forward in agent to support customized finish actions. --- agentlite/agents/BaseAgent.py | 39 ++++++++++++++++++++------------ agentlite/agents/ManagerAgent.py | 17 ++++++-------- 2 files changed, 32 insertions(+), 24 deletions(-) diff --git a/agentlite/agents/BaseAgent.py b/agentlite/agents/BaseAgent.py index 1738a19..e8698ff 100644 --- a/agentlite/agents/BaseAgent.py +++ b/agentlite/agents/BaseAgent.py @@ -32,7 +32,7 @@ class BaseAgent(ABCAgent): Your generation should follow the example format. Finish the task as best as you can.". PROMPT_TOKENS is defined in agentlite/agent_prompts/prompt_utils.py :type instruction: str, optional - :param reasoning_type: the reasoning type of this agent, defaults to "react" + :param reasoning_type: the reasoning type of this agent, defaults to "react". See BaseAgent.__add_inner_actions__ for more details. :type reasoning_type: str, optional :param logger: the logger for this agent, defaults to DefaultLogger :type logger: AgentLogger, optional @@ -95,8 +95,10 @@ def __add_inner_actions__(self): elif self.reasoning_type == "planreact": self.actions += [PlanAct, ThinkAct, FinishAct] else: - Warning("Not yet supported. Will using react instead.") - self.actions += [ThinkAct, FinishAct] + Warning("Not yet supported. Will only use input actions.") + # check if a finish action is in the action space + if not self.__check_action__(FinishAct.action_name): + Warning("Finish action is not in the action space.\n Should add an action with BaseAction.action_name==\"Finish\".") self.actions = list(set(self.actions)) def __call__(self, task: TaskPackage) -> str: @@ -228,18 +230,16 @@ def forward(self, task: TaskPackage, agent_act: AgentAct) -> str: :rtype: str """ act_found_flag = False - # if action is Finish Action - if agent_act.name == FinishAct.action_name: - act_found_flag = True - observation = "Task Completed." - task.completion = "completed" - task.answer = FinishAct(**agent_act.params) + # if match one in self.actions - else: - for action in self.actions: - if act_match(agent_act.name, action): - act_found_flag = True - observation = action(**agent_act.params) + for action in self.actions: + if act_match(agent_act.name, action): + act_found_flag = True + observation = action(**agent_act.params) + # if action is Finish Action + if agent_act.name == FinishAct.action_name: + task.answer = observation + task.completion = "completed" # if not find this action if act_found_flag: return observation @@ -263,3 +263,14 @@ def add_example( :type example_type: str, optional """ self.prompt_gen.add_example(task, action_chain, example_type=example_type) + + def __check_action__(self, action_name:str): + """check if the action is in the action space + + :param action_name: the name of the action + :type action_name: str + """ + for action in self.actions: + if act_match(action_name, action): + return True + return False \ No newline at end of file diff --git a/agentlite/agents/ManagerAgent.py b/agentlite/agents/ManagerAgent.py index 4594751..71c15a2 100644 --- a/agentlite/agents/ManagerAgent.py +++ b/agentlite/agents/ManagerAgent.py @@ -157,16 +157,13 @@ def forward(self, task: TaskPackage, agent_act: AgentAct) -> str: observation = agent(new_task_package) return observation # if action is inner action - if agent_act.name == FinishAct.action_name: - act_found_flag = True - observation = "Task Completed." - task.completion = "completed" - task.answer = FinishAct(**agent_act.params) - else: - for action in self.actions: - if act_match(agent_act.name, action): - act_found_flag = True - observation = action(**agent_act.params) + for action in self.actions: + if act_match(agent_act.name, action): + act_found_flag = True + observation = action(**agent_act.params) + if agent_act.name == FinishAct.action_name: + task.answer = observation + task.completion = "completed" # if not find this action if act_found_flag: return observation