Commit
bug fix
DavdGao committed May 9, 2024
1 parent 1b98480 commit 8f5c45f
Showing 4 changed files with 40 additions and 27 deletions.
14 changes: 7 additions & 7 deletions docs/sphinx_doc/en/source/tutorial/206-prompt.md
@@ -241,11 +241,11 @@ messages as input. The message must obey the following rules (updated in

#### Prompt Strategy

Given a list of messages, we will parse each message as follows:

- `Msg`: Fill the `role` and `content` fields directly. If the message has a
`url` field, which refers to an image, we will add it to the message.
- `List`: Parse each element in the list according to the above rules.
- If the `role` field of the first input message is `"system"`, it will be
treated as the system prompt, and the other messages will make up the
dialogue history. The dialogue history is prefixed with `"## Dialogue History"`
and merged with the system prompt into a single message whose `role` is
`"system"`, as illustrated in the sketch below.
- If the `url` attribute of any message is not `None`, we will gather all URLs
into the `"images"` field of the returned dictionary.
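
To make the strategy concrete, here is a minimal, self-contained sketch of the merging logic described above. It is an illustration under assumptions, not the library's implementation: the helper name `format_for_ollama_chat` and the plain-dict input messages are hypothetical, and the real logic lives in `OllamaChatWrapper.format`.

```python
from typing import Any, Dict, List


def format_for_ollama_chat(msgs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Illustrative sketch: merge a message list into one system message."""
    system_parts: List[str] = []
    dialogue: List[str] = []
    images: List[str] = []

    for i, msg in enumerate(msgs):
        if i == 0 and msg.get("role") == "system":
            # The first message with role "system" becomes the system prompt.
            system_parts.append(str(msg["content"]).rstrip("\n") + "\n")
        else:
            # Other messages are flattened into "name: content" lines.
            dialogue.append(f"{msg.get('name', msg['role'])}: {msg['content']}")
        # Image URLs attached to any message are collected separately.
        if msg.get("url") is not None:
            images.append(msg["url"])

    if dialogue:
        system_parts.append("## Dialogue History\n" + "\n".join(dialogue))

    system_message: Dict[str, Any] = {
        "role": "system",
        "content": "\n".join(system_parts),
    }
    if images:
        system_message["images"] = images
    return [system_message]


# Hypothetical usage mirroring the tutorial's Bob/Alice example:
print(
    format_for_ollama_chat(
        [
            {"role": "system", "name": "system",
             "content": "You are a helpful assistant"},
            {"role": "assistant", "name": "Bob", "content": "Hi."},
            {"role": "assistant", "name": "Alice",
             "content": "Nice to meet you!",
             "url": "https://example.com/image.jpg"},
        ],
    ),
)
```

Running this sketch reproduces the prompt shown in the example output below. The tutorial's own example follows.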

```python
from agentscope.models import OllamaChatWrapper
@@ -269,8 +269,8 @@ print(prompt)
```bash
[
  {
    "role": "system",
    "content": "You are a helpful assistant\n\n## Dialogue History\nBob: Hi.\nAlice: Nice to meet you!",
    "images": ["https://example.com/image.jpg"]
  },
]
10 changes: 6 additions & 4 deletions docs/sphinx_doc/zh_CN/source/tutorial/206-prompt.md
@@ -211,8 +211,10 @@ print(prompt)

Given a list of messages, we will parse each message as follows:

- `Msg`: Fill the `role` and `content` fields directly. If the message has a
`url` field pointing to an image, we will add it to the message.
- `List`: Parse each element in the list according to the above rules.
- If the `role` field of the first input message is `"system"`, that message
will be treated as the system prompt, and the remaining messages will together
form the dialogue history. The dialogue history is prefixed with
`"## Dialogue History"` and combined with the system prompt into a single
message whose `role` is `"system"`.
- If the `url` field of any input message is not `None`, the URLs will be
gathered under the `"images"` key of the returned dictionary.

```python
from agentscope.models import OllamaChatWrapper
@@ -236,8 +238,8 @@ print(prompt)
```bash
[
  {
    "role": "system",
    "content": "You are a helpful assistant\n\n## Dialogue History\nBob: Hi.\nAlice: Nice to meet you!",
    "images": ["https://example.com/image.jpg"]
  },
]
30 changes: 17 additions & 13 deletions src/agentscope/models/ollama_model.py
@@ -235,23 +235,25 @@ def format(
                 )

         # record dialog history as a list of strings
-        system_prompt = None
+        system_content_template = []
         dialogue = []
+        # TODO: here we default the url links to images
+        images = []
         for i, unit in enumerate(input_msgs):
             if i == 0 and unit.role == "system":
                 # system prompt
                 system_prompt = _convert_to_str(unit.content)
                 if not system_prompt.endswith("\n"):
                     system_prompt += "\n"
+                system_content_template.append(system_prompt)
             else:
                 # Merge all messages into a dialogue history prompt
                 dialogue.append(
                     f"{unit.name}: {_convert_to_str(unit.content)}",
                 )

-        system_content_template = []
-        if system_prompt is not None:
-            system_content_template.append("{system_prompt}")
+            if unit.url is not None:
+                images.append(unit.url)

         if len(dialogue) != 0:
             system_content_template.extend(
@@ -262,15 +264,17 @@ def format(

         system_content_template = "\n".join(system_content_template)

-        return [
-            {
-                "role": "system",
-                "content": system_content_template.format(
-                    system_prompt=system_prompt,
-                    dialogue_history=dialogue_history,
-                ),
-            },
-        ]
+        system_message = {
+            "role": "system",
+            "content": system_content_template.format(
+                dialogue_history=dialogue_history,
+            ),
+        }
+
+        if len(images) != 0:
+            system_message["images"] = images
+
+        return [system_message]


 class OllamaEmbeddingWrapper(OllamaWrapperBase):
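
For context on how the returned prompt is consumed, below is a hedged sketch of sending it to a local Ollama server with the `ollama` Python client. The model name `llama2` and the hard-coded prompt are assumptions for illustration, not part of this commit; note also that the Ollama API typically expects entries in `images` to be local file paths or base64-encoded data rather than raw URLs.

```python
import ollama  # assumption: the `ollama` Python client package is installed

# A prompt in the shape produced by the patched OllamaChatWrapper.format().
prompt = [
    {
        "role": "system",
        "content": (
            "You are a helpful assistant\n"
            "\n"
            "## Dialogue History\n"
            "Bob: Hi.\n"
            "Alice: Nice to meet you!"
        ),
    },
]

# Assumption: an Ollama server is running locally and "llama2" is pulled.
response = ollama.chat(model="llama2", messages=prompt)
print(response["message"]["content"])
```
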
13 changes: 10 additions & 3 deletions tests/format_test.py
@@ -84,9 +84,16 @@ def test_ollama_chat(self) -> None:

         # correct format
         ground_truth = [
-            {"role": "system", "content": "You are a helpful assistant"},
-            {"role": "user", "content": "What is the weather today?"},
-            {"role": "assistant", "content": "It is sunny today"},
+            {
+                "role": "system",
+                "content": (
+                    "You are a helpful assistant\n"
+                    "\n"
+                    "## Dialogue History\n"
+                    "user: What is the weather today?\n"
+                    "assistant: It is sunny today"
+                ),
+            },
         ]
         prompt = model.format(*self.inputs)  # type: ignore[arg-type]
         self.assertEqual(prompt, ground_truth)
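
To see the updated expectation reproduced outside the test suite, here is a small sketch using the public API. It assumes `agentscope` and the `ollama` Python package are installed; the constructor arguments follow the tutorial example and may differ across versions, and calling `format` alone should not require a running Ollama server.

```python
from agentscope.message import Msg
from agentscope.models import OllamaChatWrapper

# Empty config name, as in the tutorial example.
model = OllamaChatWrapper(config_name="", model_name="llama2")

prompt = model.format(
    Msg("system", "You are a helpful assistant", role="system"),
    Msg("user", "What is the weather today?", role="user"),
    Msg("assistant", "It is sunny today", role="assistant"),
)
print(prompt)  # expected to match the ground_truth above
```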
