From 2ce4cee5cd20a7136ed1a9fc6021f9fff60f56d2 Mon Sep 17 00:00:00 2001
From: DavdGao
Date: Thu, 9 May 2024 20:12:43 +0800
Subject: [PATCH] update prompt strategy in ollama chat api.

---
 src/agentscope/models/ollama_model.py | 119 +++++++++++++++++++-------
 1 file changed, 86 insertions(+), 33 deletions(-)

diff --git a/src/agentscope/models/ollama_model.py b/src/agentscope/models/ollama_model.py
index 31b136dcb..ba854f514 100644
--- a/src/agentscope/models/ollama_model.py
+++ b/src/agentscope/models/ollama_model.py
@@ -3,8 +3,6 @@
 from abc import ABC
 from typing import Sequence, Any, Optional, List, Union
 
-from loguru import logger
-
 from agentscope.message import MessageBase
 from agentscope.models import ModelWrapperBase, ModelResponse
 from agentscope.utils.tools import _convert_to_str
@@ -170,10 +168,43 @@ def format(
         self,
         *args: Union[MessageBase, Sequence[MessageBase]],
     ) -> List[dict]:
-        """A basic strategy to format the input into the required format of
-        Ollama Chat API.
+        """Format the messages for ollama Chat API.
+
+        All messages will be formatted into a single system message with
+        the system prompt and dialogue history.
+
+        Note:
+            1. This strategy may not be suitable for all scenarios,
+            and developers are encouraged to implement their own prompt
+            engineering strategies.
+            2. For ollama chat api, the content field shouldn't be an empty string.
+
+        Example:
+
+        .. code-block:: python
+
+            prompt = model.format(
+                Msg("system", "You're a helpful assistant", role="system"),
+                Msg("Bob", "Hi, how can I help you?", role="assistant"),
+                Msg("user", "What's the date today?", role="user")
+            )
+
+        The prompt will be as follows:
+
+        .. code-block:: python
+
+            [
+                {
+                    "role": "system",
+                    "content": (
+                        "You're a helpful assistant\\n\\n"
+                        "## Dialogue History\\n"
+                        "Bob: Hi, how can I help you?\\n"
+                        "user: What's the date today?"
+                    )
+                }
+            ]
 
-        Note for ollama chat api, the content field shouldn't be empty string.
 
         Args:
             args (`Union[MessageBase, Sequence[MessageBase]]`):
@@ -185,39 +216,61 @@ def format(
             `List[dict]`:
                 The formatted messages.
         """
-        ollama_msgs = []
-        for msg in args:
-            if msg is None:
-                continue
-            if isinstance(msg, MessageBase):
-                # content shouldn't be empty string
-                if msg.content == "":
-                    logger.warning(
-                        "In ollama chat API, the content field cannot be "
-                        "empty string. To avoid error, the empty string is "
-                        "replaced by a blank space automatically, but the "
-                        "model may not work as expected.",
-                    )
-                    msg.content = " "
-
-                ollama_msg = {
-                    "role": msg.role,
-                    "content": _convert_to_str(msg.content),
-                }
-
-                # image url
-                if msg.url is not None:
-                    ollama_msg["images"] = [msg.url]
-                ollama_msgs.append(ollama_msg)
-            elif isinstance(msg, list):
-                ollama_msgs.extend(self.format(*msg))
+        # Parse all information into a list of messages
+        input_msgs = []
+        for _ in args:
+            if _ is None:
+                continue
+            if isinstance(_, MessageBase):
+                input_msgs.append(_)
+            elif isinstance(_, list) and all(
+                isinstance(__, MessageBase) for __ in _
+            ):
+                input_msgs.extend(_)
             else:
                 raise TypeError(
-                    f"Invalid message type: {type(msg)}, `Msg` is expected.",
+                    f"The input should be a Msg object or a list "
+                    f"of Msg objects, got {type(_)}.",
                 )
 
-        return ollama_msgs
+        # Record the dialogue history as a list of strings
+        system_prompt = None
+        dialogue = []
+        for i, unit in enumerate(input_msgs):
+            if i == 0 and unit.role == "system":
+                # system prompt
+                system_prompt = _convert_to_str(unit.content)
+                if not system_prompt.endswith("\n"):
+                    system_prompt += "\n"
+            else:
+                # Merge all messages into a dialogue history prompt
+                dialogue.append(
+                    f"{unit.name}: {_convert_to_str(unit.content)}",
+                )
+
+        system_content_template = []
+        if system_prompt is not None:
+            system_content_template.append("{system_prompt}")
+
+        if len(dialogue) != 0:
+            system_content_template.extend(
+                ["## Dialogue History", "{dialogue_history}"],
+            )
+
+        dialogue_history = "\n".join(dialogue)
+
+        system_content_template = "\n".join(system_content_template)
+
+        return [
+            {
+                "role": "system",
+                "content": system_content_template.format(
+                    system_prompt=system_prompt,
+                    dialogue_history=dialogue_history,
+                ),
+            },
+        ]
 
 
 class OllamaEmbeddingWrapper(OllamaWrapperBase):
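
A minimal standalone sketch of the new merging strategy, for reference only and
not part of the patch: plain dicts stand in for agentscope's `Msg` objects, and
`simple_format` is a hypothetical helper that mirrors the `format` logic added
above.

.. code-block:: python

    from typing import List


    def simple_format(msgs: List[dict]) -> List[dict]:
        """Merge a leading system prompt and the dialogue history into a
        single system message, mirroring the strategy in this patch."""
        system_prompt = None
        dialogue = []
        for i, unit in enumerate(msgs):
            if i == 0 and unit["role"] == "system":
                # Treat the first message as the system prompt
                system_prompt = str(unit["content"])
                if not system_prompt.endswith("\n"):
                    system_prompt += "\n"
            else:
                # Every other message becomes one line of dialogue history
                dialogue.append(f"{unit['name']}: {unit['content']}")

        sections = []
        if system_prompt is not None:
            sections.append(system_prompt)
        if dialogue:
            sections.extend(["## Dialogue History", "\n".join(dialogue)])

        return [{"role": "system", "content": "\n".join(sections)}]


    prompt = simple_format(
        [
            {"name": "system", "role": "system", "content": "You're a helpful assistant"},
            {"name": "Bob", "role": "assistant", "content": "Hi, how can I help you?"},
            {"name": "user", "role": "user", "content": "What's the date today?"},
        ],
    )
    print(prompt[0]["content"])
    # You're a helpful assistant
    #
    # ## Dialogue History
    # Bob: Hi, how can I help you?
    # user: What's the date today?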