[Hot fix]Fix parse_func and fault_handler of ModelWrapper (#31)
pan-x-c authored Feb 21, 2024
1 parent 1b30932 commit 56478f3
Showing 5 changed files with 33 additions and 11 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -253,8 +253,8 @@ from agentscope.agents import AgentBase

class MyAgent(AgentBase):
    def reply(self, x):
-        # Do something here, e.g. calling your model
-        response = self.model(x)
+        # Do something here, e.g. calling your model and get the raw field as your agent's response
+        response = self.model(x).raw
        return response
```
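
The switch to `.raw` reflects that a model call now returns a response object with separate fields rather than a bare value. A minimal, self-contained sketch of that idea (the `FakeResponse`/`FakeModel` names below are illustrative stand-ins, not the AgentScope API):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeResponse:
    """Illustrative stand-in for a response object with separate fields."""
    text: Optional[str] = None  # plain text returned by the model
    raw: Optional[dict] = None  # structured payload, e.g. parsed JSON


class FakeModel:
    """Illustrative model wrapper that returns a response object."""

    def __call__(self, x: dict) -> FakeResponse:
        # Pretend the backend answered with a JSON-style payload.
        return FakeResponse(text='{"speak": "hello"}', raw={"speak": "hello"})


model = FakeModel()
response = model({"content": "hi"})
print(response.raw)   # {'speak': 'hello'} -- the field the updated README accesses
print(response.text)  # '{"speak": "hello"}'
```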

4 changes: 2 additions & 2 deletions docs/sphinx_doc/source/tutorial/201-agent.md
@@ -90,10 +90,10 @@ def reply(self, x: dict = None) -> dict:
    prompt = self.engine.join(self.sys_prompt, self.memory.get_memory())

    # Invoke the language model with the prepared prompt
-    response = self.model(prompt, parse_func=json.loads, fault_handler=lambda x: {"speak": x})
+    response = self.model(prompt).text

    # Format the response and create a message object
-    msg = Msg(self.name, response.get("speak", None) or response, **response)
+    msg = Msg(self.name, response)

    # Record the message to memory and return it
    self.memory.add(msg)
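
The tutorial drops the `parse_func`/`fault_handler` arguments here because, after this fix, those hooks work on response objects rather than bare strings (see `parse_dict` in `dict_dialog_agent.py` below). A rough, self-contained sketch of adapting a plain string parser such as `json.loads` to that convention (`wrap_plain_parser` and `DummyResponse` are illustrative names, not AgentScope API):

```python
import json
from typing import Any, Callable


class DummyResponse:
    """Illustrative response object exposing .text and .raw fields."""

    def __init__(self, text: str) -> None:
        self.text = text
        self.raw = None


def wrap_plain_parser(parser: Callable[[str], Any]) -> Callable[[DummyResponse], DummyResponse]:
    """Lift a plain text parser into one that reads .text and fills .raw."""

    def parse(response: DummyResponse) -> DummyResponse:
        response.raw = parser(response.text)
        return response

    return parse


parse = wrap_plain_parser(json.loads)
print(parse(DummyResponse('{"speak": "hi"}')).raw)  # {'speak': 'hi'}
```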
4 changes: 2 additions & 2 deletions docs/sphinx_doc/source/tutorial/204-service.md
@@ -118,12 +118,12 @@ class YourAgent(AgentBase):
        prompt += params_prompt

        # Get the model response
-        model_response = self.model(prompt)
+        model_response = self.model(prompt).text

        # Parse the model response and call the create_file function
        # Additional extraction functions might be necessary
        try:
-            kwargs = json.loads(model_response.content)
+            kwargs = json.loads(model_response)
            create_file(**kwargs)
        except:
            # Error handling
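
Since `.text` is a plain string, the `json.loads` call in the tutorial operates on it directly; a small self-contained sketch of this parse-then-call pattern (the `create_file` stub and the hard-coded response string are placeholders for the real service function and model output):

```python
import json


def create_file(path: str = "", content: str = "") -> None:
    """Placeholder for the real create_file service function."""
    print(f"create_file(path={path!r}, content={content!r})")


# Stands in for `self.model(prompt).text` in the tutorial snippet.
model_response = '{"path": "notes.txt", "content": "hello"}'

try:
    kwargs = json.loads(model_response)
    create_file(**kwargs)
except json.JSONDecodeError:
    # Error handling: the model did not return valid JSON.
    print("Could not parse the model response as JSON.")
```
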
23 changes: 18 additions & 5 deletions src/agentscope/agents/dict_dialog_agent.py
@@ -7,10 +7,21 @@

from ..message import Msg
from .agent import AgentBase
+from ..models.model import ModelResponse
from ..prompt import PromptEngine
from ..prompt import PromptType


+def parse_dict(response: ModelResponse) -> ModelResponse:
+    """Parse function for DictDialogAgent"""
+    return ModelResponse(raw=json.loads(response.text))
+
+
+def default_response(response: ModelResponse) -> ModelResponse:
+    """The default response of fault_handler"""
+    return ModelResponse(raw={"speak": response.text})
+
+
class DictDialogAgent(AgentBase):
    """An agent that generates response in a dict format, where user can
    specify the required fields in the response via prompt, e.g.
@@ -40,8 +51,8 @@ def __init__(
        model_config_name: str = None,
        use_memory: bool = True,
        memory_config: Optional[dict] = None,
-        parse_func: Optional[Callable[..., Any]] = json.loads,
-        fault_handler: Optional[Callable[..., Any]] = lambda x: {"speak": x},
+        parse_func: Optional[Callable[..., Any]] = parse_dict,
+        fault_handler: Optional[Callable[..., Any]] = default_response,
        max_retries: Optional[int] = 3,
        prompt_type: Optional[PromptType] = PromptType.LIST,
    ) -> None:
@@ -60,11 +71,13 @@ def __init__(
                Whether the agent has memory.
            memory_config (`Optional[dict]`, defaults to `None`):
                The config of memory.
-            parse_func (`Optional[Callable[..., Any]]`, defaults to `None`):
+            parse_func (`Optional[Callable[..., Any]]`,
+                defaults to `parse_dict`):
                The function used to parse the model output,
                e.g. `json.loads`, which is used to extract json from the
                output.
-            fault_handler (`Optional[Callable[..., Any]]`, defaults to `None`):
+            fault_handler (`Optional[Callable[..., Any]]`,
+                defaults to `default_response`):
                The function used to handle the fault when parse_func fails
                to parse the model output.
            max_retries (`Optional[int]`, defaults to `None`):
@@ -129,7 +142,7 @@ def reply(self, x: dict = None) -> dict:
            parse_func=self.parse_func,
            fault_handler=self.fault_handler,
            max_retries=self.max_retries,
-        ).text
+        ).raw

        # logging raw messages in debug mode
        logger.debug(json.dumps(response, indent=4))
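
The crux of the fix is that `parse_func` and `fault_handler` now take and return `ModelResponse` objects instead of bare strings, with `max_retries` bounding how often parsing is re-attempted. A rough, self-contained sketch of how these hooks compose (the `call_with_retries` loop and the minimal `ModelResponse` stand-in below are illustrative approximations, not the actual `_response_parse_decorator` implementation):

```python
import json
from typing import Callable, Optional


class ModelResponse:
    """Minimal stand-in mirroring the fields used by this fix."""

    def __init__(self, text: Optional[str] = None, raw: Optional[dict] = None) -> None:
        self.text = text
        self.raw = raw


def parse_dict(response: ModelResponse) -> ModelResponse:
    """Parse function for DictDialogAgent (as added in this commit)."""
    return ModelResponse(raw=json.loads(response.text))


def default_response(response: ModelResponse) -> ModelResponse:
    """Fallback used when parsing fails (as added in this commit)."""
    return ModelResponse(raw={"speak": response.text})


def call_with_retries(
    model_call: Callable[[], ModelResponse],
    parse_func: Callable[[ModelResponse], ModelResponse] = parse_dict,
    fault_handler: Callable[[ModelResponse], ModelResponse] = default_response,
    max_retries: int = 3,
) -> ModelResponse:
    """Illustrative approximation of the retry-then-fallback behaviour."""
    response = None
    for _ in range(max_retries):
        response = model_call()
        try:
            return parse_func(response)
        except Exception:
            continue  # re-query the model and try parsing again
    return fault_handler(response)


# Example: a "model" that always returns invalid JSON falls back to {"speak": ...}.
result = call_with_retries(lambda: ModelResponse(text="not json"))
print(result.raw)  # {'speak': 'not json'}
```
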
9 changes: 9 additions & 0 deletions src/agentscope/models/model.py
@@ -107,6 +107,15 @@ def raw(self) -> dict:
        """Raw dictionary field."""
        return self._raw

+    def __str__(self) -> str:
+        serialized_fields = {
+            "text": self.text,
+            "embedding": self.embedding,
+            "image_urls": self.image_urls,
+            "raw": self.raw,
+        }
+        return json.dumps(serialized_fields, indent=4)
+

def _response_parse_decorator(
    model_call: Callable,
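
The new `__str__` simply dumps the response's public fields as indented JSON, which makes logged responses easy to read. A stand-alone sketch reproducing that behaviour (the `DemoResponse` class below only mirrors the fields shown in the diff, not the full ModelResponse implementation):

```python
import json
from typing import Optional, Sequence


class DemoResponse:
    """Stand-alone mirror of the fields serialized by the new __str__."""

    def __init__(
        self,
        text: Optional[str] = None,
        embedding: Optional[Sequence] = None,
        image_urls: Optional[Sequence[str]] = None,
        raw: Optional[dict] = None,
    ) -> None:
        self.text = text
        self.embedding = embedding
        self.image_urls = image_urls
        self.raw = raw

    def __str__(self) -> str:
        serialized_fields = {
            "text": self.text,
            "embedding": self.embedding,
            "image_urls": self.image_urls,
            "raw": self.raw,
        }
        return json.dumps(serialized_fields, indent=4)


# Printing the object now yields readable, pretty-printed JSON.
print(DemoResponse(text="hello", raw={"speak": "hello"}))
```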
