[HOTFIX] Update prompt strategy in Ollama chat API. (#206)
DavdGao authored May 11, 2024
1 parent b8ed068 commit 9d4a1bb
Showing 4 changed files with 119 additions and 49 deletions.
18 changes: 10 additions & 8 deletions docs/sphinx_doc/en/source/tutorial/206-prompt.md
@@ -241,11 +241,11 @@ messages as input. The message must obey the following rules (updated in

#### Prompt Strategy

Given a list of messages, we will parse each message as follows:

- `Msg`: Fill the `role` and `content` fields directly. If it has a `url`
field, which refers to an image, we will add it to the message.
- `List`: Parse each element in the list according to the above rules.
- If the `role` field of the first input message is `"system"`, it will be
treated as the system prompt, and the remaining messages will form the
dialogue history, which is appended to the system message with the prefix
`"## Dialogue History"`.
- If the `url` attribute of any message is not `None`, all URLs will be
gathered into the `"images"` field of the returned dictionary.

```python
from agentscope.models import OllamaChatWrapper
# ... (the rest of this example is collapsed in the diff)
print(prompt)
```

@@ -268,9 +268,11 @@

```bash
# Before this commit (old prompt strategy):
[
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "assistant", "content": "Hi."},
    {"role": "assistant", "content": "Nice to meet you!", "images": ["https://example.com/image.jpg"]},
]

# After this commit (new prompt strategy):
[
    {
        "role": "system",
        "content": "You are a helpful assistant\n\n## Dialogue History\nBob: Hi.\nAlice: Nice to meet you!",
        "images": ["https://example.com/image.jpg"]
    },
]
```
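The middle of the example above is collapsed in this diff. As a reference, here is a minimal sketch of a call that could produce output of this shape; the wrapper's constructor arguments are assumptions for illustration, while the agent names, contents, and image URL mirror the output shown above.

```python
from agentscope.message import Msg
from agentscope.models import OllamaChatWrapper

# The config and model names below are assumptions, not taken from the tutorial.
model = OllamaChatWrapper(config_name="ollama_chat", model_name="llama2")

prompt = model.format(
    Msg("system", "You are a helpful assistant", role="system"),
    [
        Msg("Bob", "Hi.", role="assistant"),
        Msg(
            "Alice",
            "Nice to meet you!",
            role="assistant",
            url="https://example.com/image.jpg",
        ),
    ],
)
print(prompt)
```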

14 changes: 9 additions & 5 deletions docs/sphinx_doc/zh_CN/source/tutorial/206-prompt.md
@@ -211,8 +211,10 @@ print(prompt)

Given a list of messages, we will parse each message according to the following rules:

- `Msg`: Fill the `role` and `content` fields directly. If it has a `url`
field pointing to an image, we will add it to the message.
- `List`: Parse each element in the list according to the above rules.
- If the `role` field of the first input message is `"system"`, that message
will be treated as the system prompt, and the remaining messages will together
form the dialogue history. The dialogue history is prefixed with
`"## Dialogue History"` and combined with the system prompt into a single
message whose `role` is `"system"`.
- If the `url` field of any input message is not `None`, those URLs will be
gathered into the value of the `"images"` key.

```python
from agentscope.models import OllamaChatWrapper
# ... (the rest of this example is collapsed in the diff)
print(prompt)
```

@@ -235,9 +237,11 @@

```bash
# Before this commit (old prompt strategy):
[
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "assistant", "content": "Hi."},
    {"role": "assistant", "content": "Nice to meet you!", "images": ["https://example.com/image.jpg"]},
]

# After this commit (new prompt strategy):
[
    {
        "role": "system",
        "content": "You are a helpful assistant\n\n## Dialogue History\nBob: Hi.\nAlice: Nice to meet you!",
        "images": ["https://example.com/image.jpg"]
    },
]
```

123 changes: 90 additions & 33 deletions src/agentscope/models/ollama_model.py
@@ -3,8 +3,6 @@
from abc import ABC
from typing import Sequence, Any, Optional, List, Union

from loguru import logger

from agentscope.message import MessageBase
from agentscope.models import ModelWrapperBase, ModelResponse
from agentscope.utils.tools import _convert_to_str
@@ -170,10 +168,43 @@ def format(
self,
*args: Union[MessageBase, Sequence[MessageBase]],
) -> List[dict]:
"""A basic strategy to format the input into the required format of
Ollama Chat API.
"""Format the messages for ollama Chat API.
All messages will be formatted into a single system message with
system prompt and dialogue history.
Note:
1. This strategy may not be suitable for all scenarios,
and developers are encouraged to implement their own prompt
engineering strategies.
2. For ollama chat api, the content field shouldn't be an empty string.
Example:
.. code-block:: python
prompt = model.format(
Msg("system", "You're a helpful assistant", role="system"),
Msg("Bob", "Hi, how can I help you?", role="assistant"),
Msg("user", "What's the date today?", role="user")
)
The prompt will be as follows:
.. code-block:: python
[
{
"role": "user",
"content": (
"You're a helpful assistant\\n\\n"
"## Dialogue History\\n"
"Bob: Hi, how can I help you?\\n"
"user: What's the date today?"
)
}
]
Note for ollama chat api, the content field shouldn't be an empty string.
Args:
args (`Union[MessageBase, Sequence[MessageBase]]`):
@@ -185,39 +216,65 @@ def format(
`List[dict]`:
The formatted messages.
"""
ollama_msgs = []
for msg in args:
if msg is None:
continue
if isinstance(msg, MessageBase):
# content shouldn't be empty string
if msg.content == "":
logger.warning(
"In ollama chat API, the content field cannot be "
"empty string. To avoid error, the empty string is "
"replaced by a blank space automatically, but the "
"model may not work as expected.",
)
msg.content = " "

ollama_msg = {
"role": msg.role,
"content": _convert_to_str(msg.content),
}

# image url
if msg.url is not None:
ollama_msg["images"] = [msg.url]

ollama_msgs.append(ollama_msg)
elif isinstance(msg, list):
ollama_msgs.extend(self.format(*msg))
# Parse all information into a list of messages
input_msgs = []
for _ in args:
if _ is None:
continue
if isinstance(_, MessageBase):
input_msgs.append(_)
elif isinstance(_, list) and all(
isinstance(__, MessageBase) for __ in _
):
input_msgs.extend(_)
else:
raise TypeError(
f"Invalid message type: {type(msg)}, `Msg` is expected.",
f"The input should be a Msg object or a list "
f"of Msg objects, got {type(_)}.",
)

return ollama_msgs
# Record dialogue history as a list of strings
system_content_template = []
dialogue = []
# TODO: here we default the url links to images
images = []
for i, unit in enumerate(input_msgs):
if i == 0 and unit.role == "system":
# system prompt
system_prompt = _convert_to_str(unit.content)
if not system_prompt.endswith("\n"):
system_prompt += "\n"
system_content_template.append(system_prompt)
else:
# Merge all messages into a dialogue history prompt
dialogue.append(
f"{unit.name}: {_convert_to_str(unit.content)}",
)

if unit.url is not None:
images.append(unit.url)

if len(dialogue) != 0:
system_content_template.extend(
["## Dialogue History", "{dialogue_history}"],
)

dialogue_history = "\n".join(dialogue)

system_content_template = "\n".join(system_content_template)

system_message = {
"role": "system",
"content": system_content_template.format(
dialogue_history=dialogue_history,
),
}

if len(images) != 0:
system_message["images"] = images

return [system_message]


class OllamaEmbeddingWrapper(OllamaWrapperBase):
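As a usage sketch of the rewritten `format` above: the `Msg` objects mirror the docstring example, while the wrapper's constructor arguments and the direct call to the `ollama` Python client are assumptions for illustration rather than code from this commit.

```python
import ollama

from agentscope.message import Msg
from agentscope.models import OllamaChatWrapper

# Constructor arguments are assumptions for illustration.
model = OllamaChatWrapper(config_name="ollama_chat", model_name="llama2")

# format() collapses the system prompt and the dialogue history into a
# single system message (with any image URLs gathered under "images").
messages = model.format(
    Msg("system", "You're a helpful assistant", role="system"),
    Msg("Bob", "Hi, how can I help you?", role="assistant"),
    Msg("user", "What's the date today?", role="user"),
)

# The merged message list can be handed straight to the Ollama chat endpoint.
response = ollama.chat(model="llama2", messages=messages)
print(response["message"]["content"])
```

Since the whole conversation is carried in the system message, downstream calls only ever see a single-element `messages` list.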
13 changes: 10 additions & 3 deletions tests/format_test.py
@@ -84,9 +84,16 @@ def test_ollama_chat(self) -> None:

# correct format
ground_truth = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "What is the weather today?"},
{"role": "assistant", "content": "It is sunny today"},
{
"role": "system",
"content": (
"You are a helpful assistant\n"
"\n"
"## Dialogue History\n"
"user: What is the weather today?\n"
"assistant: It is sunny today"
),
},
]
prompt = model.format(*self.inputs) # type: ignore[arg-type]
self.assertEqual(prompt, ground_truth)
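The `self.inputs` fixture used above is collapsed in this diff; a hypothetical reconstruction consistent with the ground truth (the real fixture in the repository may differ) could look like:

```python
from agentscope.message import Msg

# Hypothetical inputs matching the ground-truth prompt in the test above.
inputs = [
    Msg("system", "You are a helpful assistant", role="system"),
    [
        Msg("user", "What is the weather today?", role="user"),
        Msg("assistant", "It is sunny today", role="assistant"),
    ],
]
```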
