Skip to content

Commit

Permalink
[Hotfix] Fix memory loading error in TemporaryMemory class (#197)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiTao-Li authored May 11, 2024
1 parent 6d30bf9 commit b8ed068
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 58 deletions.
55 changes: 48 additions & 7 deletions src/agentscope/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
"""

from abc import ABC, abstractmethod
from typing import Iterable
from typing import Iterable, Sequence
from typing import Optional
from typing import Union
from typing import Callable

from ..message import MessageBase


class MemoryBase(ABC):
"""Base class for memory."""
Expand All @@ -33,6 +35,8 @@ def __init__(
def update_config(self, config: dict) -> None:
"""
Configure memory as specified in config
Args:
config (`dict`): Configuration of resetting this memory
"""
self.config = config

Expand All @@ -43,41 +47,78 @@ def get_memory(
filter_func: Optional[Callable[[int, dict], bool]] = None,
) -> list:
"""
Return a certain range (`recent_n` or all) of memory, filtered by
`filter_func`
Return a certain range (`recent_n` or all) of memory,
filtered by `filter_func`
Args:
recent_n (int, optional):
indicate the most recent N memory pieces to be returned.
filter_func (Optional[Callable[[int, dict], bool]]):
filter function to decide which pieces of memory should
be returned, taking the index and a piece of memory as
input and return True (return this memory) or False
(does not return)
"""

@abstractmethod
def add(self, memories: Union[list[dict], dict, None]) -> None:
def add(
self,
memories: Union[Sequence[dict], dict, None],
) -> None:
"""
Adding new memory fragment, depending on how the memory are stored
Args:
memories (Union[Sequence[dict], dict, None]):
Memories to be added. If the memory is not in MessageBase,
it will first be converted into a message type.
"""

@abstractmethod
def delete(self, index: Union[Iterable, int]) -> None:
"""
Delete memory fragment, depending on how the memory are stored
and matched
Args:
index (Union[Iterable, int]):
indices of the memory fragments to delete
"""

@abstractmethod
def load(
self,
memories: Union[str, dict, list],
memories: Union[str, list[MessageBase], MessageBase],
overwrite: bool = False,
) -> None:
"""
Load memory, depending on how the memory are passed, design to load
from both file or dict
Args:
memories (Union[str, list[MessageBase], MessageBase]):
memories to be loaded.
If it is in str type, it will be first checked if it is a
file; otherwise it will be deserialized as messages.
Otherwise, memories must be either in message type or list
of messages.
overwrite (bool):
if True, clear the current memory before loading the new ones;
if False, memories will be appended to the old one at the end.
"""

@abstractmethod
def export(
self,
to_mem: bool = False,
file_path: Optional[str] = None,
to_mem: bool = False,
) -> Optional[list]:
"""Export memory, depending on how the memory are stored"""
"""
Export memory, depending on how the memory are stored
Args:
file_path (Optional[str]):
file path to save the memory to.
to_mem (Optional[str]):
if True, just return the list of messages in memory
Notice: this method prevents file_path is None when to_mem
is False.
"""

@abstractmethod
def clear(self) -> None:
Expand Down
94 changes: 87 additions & 7 deletions src/agentscope/memory/temporary_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@
from ..models import load_model_by_config_name
from ..service.retrieval.retrieval_from_list import retrieve_from_list
from ..service.retrieval.similarity import Embedding
from ..message import (
deserialize,
serialize,
MessageBase,
Msg,
Tht,
PlaceholderMessage,
)


class TemporaryMemory(MemoryBase):
Expand All @@ -28,6 +36,16 @@ def __init__(
config: Optional[dict] = None,
embedding_model: Union[str, Callable] = None,
) -> None:
"""
Temporary memory module for conversation.
Args:
config (dict):
configuration of the memory
embedding_model (Union[str, Callable])
if the temporary memory needs to be embedded,
then either pass the name of embedding model or
the embedding model itself.
"""
super().__init__(config)

self._content = []
Expand All @@ -43,17 +61,48 @@ def add(
memories: Union[Sequence[dict], dict, None],
embed: bool = False,
) -> None:
# pylint: disable=too-many-branches
"""
Adding new memory fragment, depending on how the memory are stored
Args:
memories (Union[Sequence[dict], dict, None]):
memories to be added. If the memory is not in MessageBase,
it will first be converted into a message type.
embed (bool):
whether to generate embedding for the new added memories
"""
if memories is None:
return

if not isinstance(memories, list):
if not isinstance(memories, Sequence):
record_memories = [memories]
else:
record_memories = memories

# if memory doesn't have id attribute, we skip the checking
memories_idx = set(_.id for _ in self._content if hasattr(_, "id"))
for memory_unit in record_memories:
if not issubclass(type(memory_unit), MessageBase):
try:
if (
"name" in memory_unit
and memory_unit["name"] == "thought"
):
memory_unit = Tht(**memory_unit)
else:
memory_unit = Msg(**memory_unit)
except Exception as exc:
raise ValueError(
f"Cannot add {memory_unit} to memory, "
f"must be with subclass of MessageBase",
) from exc

# in case this is a PlaceholderMessage, try to update
# the values first
if isinstance(memory_unit, PlaceholderMessage):
memory_unit.update_value()
memory_unit = Msg(**memory_unit)

# add to memory if it's new
if (
not hasattr(memory_unit, "id")
Expand All @@ -71,6 +120,13 @@ def add(
self._content.append(memory_unit)

def delete(self, index: Union[Iterable, int]) -> None:
"""
Delete memory fragment, depending on how the memory are stored
and matched
Args:
index (Union[Iterable, int]):
indices of the memory fragments to delete
"""
if self.size() == 0:
logger.warning(
"The memory is empty, and the delete operation is "
Expand Down Expand Up @@ -101,16 +157,26 @@ def delete(self, index: Union[Iterable, int]) -> None:

def export(
self,
to_mem: bool = False,
file_path: Optional[str] = None,
to_mem: bool = False,
) -> Optional[list]:
"""Export memory to json file"""
"""
Export memory, depending on how the memory are stored
Args:
file_path (Optional[str]):
file path to save the memory to. The messages will
be serialized and written to the file.
to_mem (Optional[str]):
if True, just return the list of messages in memory
Notice: this method prevents file_path is None when to_mem
is False.
"""
if to_mem:
return self._content

if to_mem is False and file_path is not None:
with open(file_path, "w", encoding="utf-8") as f:
json.dump(self._content, f, indent=4)
f.write(serialize(self._content))
else:
raise NotImplementedError(
"file type only supports "
Expand All @@ -120,16 +186,30 @@ def export(

def load(
self,
memories: Union[str, dict, list],
memories: Union[str, list[MessageBase], MessageBase],
overwrite: bool = False,
) -> None:
"""
Load memory, depending on how the memory are passed, design to load
from both file or dict
Args:
memories (Union[str, list[MessageBase], MessageBase]):
memories to be loaded.
If it is in str type, it will be first checked if it is a
file; otherwise it will be deserialized as messages.
Otherwise, memories must be either in message type or list
of messages.
overwrite (bool):
if True, clear the current memory before loading the new ones;
if False, memories will be appended to the old one at the end.
"""
if isinstance(memories, str):
if os.path.isfile(memories):
with open(memories, "r", encoding="utf-8") as f:
self.add(json.load(f))
load_memories = deserialize(f.read())
else:
try:
load_memories = json.loads(memories)
load_memories = deserialize(memories)
if not isinstance(load_memories, dict) and not isinstance(
load_memories,
list,
Expand Down
8 changes: 7 additions & 1 deletion src/agentscope/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,12 +192,18 @@ def __init__(
self,
content: Any,
timestamp: Optional[str] = None,
**kwargs: Any,
) -> None:
if "name" in kwargs:
kwargs.pop("name")
if "role" in kwargs:
kwargs.pop("role")
super().__init__(
name="thought",
content=content,
role="assistant",
timestamp=timestamp,
**kwargs,
)

def to_str(self) -> str:
Expand Down Expand Up @@ -399,7 +405,7 @@ def deserialize(s: str) -> Union[MessageBase, Sequence]:
return [deserialize(s) for s in js_msg["__value"]]
elif msg_type not in _MSGS:
raise NotImplementedError(
"Deserialization of {msg_type} is not supported.",
f"Deserialization of {msg_type} is not supported.",
)
return _MSGS[msg_type](**js_msg)

Expand Down
Loading

0 comments on commit b8ed068

Please sign in to comment.