Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix error in agent when use_memory is False #74

Merged
merged 4 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions docs/sphinx_doc/en/source/tutorial/201-agent.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,14 @@ Below, we provide usages of how to configure various agents from the AgentPool:
def reply(self, x: dict = None) -> dict:
# Additional processing steps can occur here

if x is not None:
if self.memory:
self.memory.add(x) # Update the memory with the input

# Generate a prompt for the language model using the system prompt and memory
prompt = self.engine.join(self.sys_prompt, self.memory.get_memory())
prompt = self.engine.join(
self.sys_prompt,
self.memory and self.memory.get_memory(),
)

# Invoke the language model with the prepared prompt
response = self.model(prompt).text
Expand Down Expand Up @@ -129,7 +132,7 @@ def reply(
required_keys: Optional[Union[list[str], str]] = None,
) -> dict:
# Check if there is initial data to be added to memory
if x is not None:
if self.memory:
self.memory.add(x)

content = input(f"{self.name}: ") # Prompt the user for input
Expand All @@ -151,7 +154,8 @@ def reply(
msg = Msg(self.name, content=content, url=url, **kwargs)

# Add the message object to memory
self.memory.add(msg)
if self.memory:
self.memory.add(msg)
return msg
```

Expand Down
16 changes: 11 additions & 5 deletions docs/sphinx_doc/zh_CN/source/tutorial_zh/201-agent.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,14 @@ class AgentBase(Operator):
def reply(self, x: dict = None) -> dict:
# Additional processing steps can occur here

if x is not None:
if self.memory:
self.memory.add(x) # Update the memory with the input

# Generate a prompt for the language model using the system prompt and memory
prompt = self.engine.join(self.sys_prompt, self.memory.get_memory())
prompt = self.engine.join(
self.sys_prompt,
self.memory and self.memory.get_memory(),
)

# Invoke the language model with the prepared prompt
response = self.model(prompt).text
Expand All @@ -97,7 +100,8 @@ def reply(self, x: dict = None) -> dict:
msg = Msg(self.name, response)

# Record the message to memory and return it
self.memory.add(msg)
if self.memory:
self.memory.add(msg)
return msg
```

Expand Down Expand Up @@ -130,7 +134,7 @@ def reply(
required_keys: Optional[Union[list[str], str]] = None,
) -> dict:
# Check if there is initial data to be added to memory
if x is not None:
if self.memory:
self.memory.add(x)

content = input(f"{self.name}: ") # Prompt the user for input
Expand All @@ -152,7 +156,9 @@ def reply(
msg = Msg(self.name, content=content, url=url, **kwargs)

# Add the message object to memory
self.memory.add(msg)
if self.memory:
self.memory.add(msg)

return msg
```

Expand Down
2 changes: 2 additions & 0 deletions src/agentscope/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def __init__(

if use_memory:
self.memory = TemporaryMemory(memory_config)
else:
self.memory = None

# The audience of this agent, which means if this agent generates a
# response, it will be passed to all agents in the audience.
Expand Down
13 changes: 7 additions & 6 deletions src/agentscope/agents/dialog_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class DialogAgent(AgentBase):
def __init__(
self,
name: str,
sys_prompt: Optional[str] = None,
model_config_name: str = None,
sys_prompt: str,
model_config_name: str,
use_memory: bool = True,
memory_config: Optional[dict] = None,
prompt_type: Optional[PromptType] = PromptType.LIST,
Expand All @@ -29,7 +29,7 @@ def __init__(
sys_prompt (`Optional[str]`):
The system prompt of the agent, which can be passed by args
or hard-coded in the agent.
model_config_name (`str`, defaults to None):
model_config_name (`str`):
The name of the model config, which is used to load model from
configuration.
use_memory (`bool`, defaults to `True`):
Expand Down Expand Up @@ -68,13 +68,13 @@ def reply(self, x: dict = None) -> dict:
response to the user's input.
"""
# record the input if needed
if x is not None:
if self.memory:
DavdGao marked this conversation as resolved.
Show resolved Hide resolved
self.memory.add(x)

# prepare prompt
prompt = self.engine.join(
self.sys_prompt,
self.memory.get_memory(),
self.memory and self.memory.get_memory(),
)

# call llm and generate response
Expand All @@ -85,6 +85,7 @@ def reply(self, x: dict = None) -> dict:
self.speak(msg)

# Record the message in memory
self.memory.add(msg)
if self.memory:
self.memory.add(msg)

return msg
11 changes: 6 additions & 5 deletions src/agentscope/agents/dict_dialog_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class DictDialogAgent(AgentBase):
def __init__(
self,
name: str,
sys_prompt: Optional[str] = None,
model_config_name: str = None,
sys_prompt: str,
model_config_name: str,
use_memory: bool = True,
memory_config: Optional[dict] = None,
parse_func: Optional[Callable[..., Any]] = parse_dict,
Expand Down Expand Up @@ -127,13 +127,13 @@ def reply(self, x: dict = None) -> dict:
it defaults to treating the response as plain text.
"""
# record the input if needed
if x is not None:
if self.memory:
self.memory.add(x)

# prepare prompt
prompt = self.engine.join(
self.sys_prompt,
self.memory.get_memory(),
self.memory and self.memory.get_memory(),
)

# call llm
Expand All @@ -158,6 +158,7 @@ def reply(self, x: dict = None) -> dict:
self.speak(msg)

# record to memory
self.memory.add(msg)
if self.memory:
self.memory.add(msg)

return msg
12 changes: 8 additions & 4 deletions src/agentscope/agents/text_to_image_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ class TextToImageAgent(AgentBase):
def __init__(
self,
name: str,
sys_prompt: Optional[str] = None,
model_config_name: str = None,
sys_prompt: str,
model_config_name: str,
use_memory: bool = True,
memory_config: Optional[dict] = None,
) -> None:
Expand Down Expand Up @@ -44,8 +44,9 @@ def __init__(
)

def reply(self, x: dict = None) -> dict:
if x is not None:
if self.memory:
self.memory.add(x)

image_urls = self.model(x.content).image_urls
# TODO: optimize the construction of content
msg = Msg(
Expand All @@ -54,5 +55,8 @@ def reply(self, x: dict = None) -> dict:
url=image_urls,
)
logger.chat(msg)
self.memory.add(msg)

if self.memory:
self.memory.add(msg)

return msg
5 changes: 3 additions & 2 deletions src/agentscope/agents/user_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def reply(
the user's input and any additional details. This is also
stored in the object's memory.
"""
if x is not None:
if self.memory:
self.memory.add(x)

# TODO: To avoid order confusion, because `input` print much quicker
Expand Down Expand Up @@ -91,6 +91,7 @@ def reply(
)

# Add to memory
self.memory.add(msg)
if self.memory:
self.memory.add(msg)

return msg
2 changes: 1 addition & 1 deletion src/agentscope/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def get_memory(
"""

@abstractmethod
def add(self, memories: Union[list[dict], dict]) -> None:
def add(self, memories: Union[list[dict], dict, None]) -> None:
"""
Adding new memory fragment, depending on how the memory are stored
"""
Expand Down
5 changes: 4 additions & 1 deletion src/agentscope/memory/temporary_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,12 @@ def __init__(

def add(
self,
memories: Union[Sequence[dict], dict],
memories: Union[Sequence[dict], dict, None],
embed: bool = False,
) -> None:
if memories is None:
return

if not isinstance(memories, list):
record_memories = [memories]
else:
Expand Down
4 changes: 4 additions & 0 deletions src/agentscope/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def join(
converted to `Msg` from `system`.
"""
# TODO: achieve the summarize function

# Filter `None`
args = [_ for _ in args if _ is not None]

if self.prompt_type == PromptType.STRING:
return self.join_to_str(*args, format_map=format_map)
elif self.prompt_type == PromptType.LIST:
Expand Down
4 changes: 2 additions & 2 deletions src/agentscope/service/web_search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def bing_search(
Example:
.. code-block:: python

results = _search_bing(question="What is an agent?",
results = bing_search(question="What is an agent?",
bing_api_key="your bing api key",
num_results=2,
mkt="en-US")
Expand Down Expand Up @@ -150,7 +150,7 @@ def google_search(
Example:
.. code-block:: python

results = _search_google(
results = google_search(
'Python programming',
'your_google_api_key',
'your_cse_id',
Expand Down
3 changes: 1 addition & 2 deletions tests/rpc_agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ class DemoRpcAgentWithMemory(AgentBase):
"""A demo Rpc agent that count its memory"""

def reply(self, x: dict = None) -> dict:
if x is not None:
self.memory.add(x)
self.memory.add(x)
msg = Msg(name=self.name, content={"mem_size": self.memory.size()})
self.memory.add(msg)
time.sleep(1)
Expand Down
Loading