Skip to content

Commit

Permalink
Merge pull request #1072 from MB-Finski/main
Browse files: browse the repository at this point in the history
Fix: append correct latest user message when using Ollama
  • Loading branch information
ashpreetbedi authored Oct 3, 2024
2 parents 8f81679 + e077352 commit 6d322fe
Showing 1 changed file with 23 additions and 17 deletions.
40 changes: 23 additions & 17 deletions phi/llm/ollama/chat.py
Original file line number | Diff line number | Diff line change
@@ -105,12 +105,18 @@ def deactivate_function_calls(self) -> None:
# This is triggered when the function call limit is reached.
self.format = ""

def response(self, messages: List[Message]) -> str:
def response(self, messages: List[Message], current_user_query: str | None = None) -> str:
logger.debug("---------- Ollama Response Start ----------")
# -*- Log messages for debugging
for m in messages:
m.log()

if current_user_query is None:
for m in reversed(messages):
if m.role == "user":
current_user_query = m.content
break

response_timer = Timer()
response_timer.start()
response: Mapping[str, Any] = self.invoke(messages=messages)
@@ -196,7 +202,7 @@ def response(self, messages: List[Message]) -> str:
messages = self.add_tool_call_error_message(messages)

# -*- Yield new response using results of tool calls
final_response += self.response(messages=messages)
final_response += self.response(messages=messages, current_user_query=current_user_query)
return final_response

elif assistant_message.tool_calls is not None and self.run_tools:
@@ -224,7 +230,7 @@ def response(self, messages: List[Message]) -> str:

# This case rarely happens but it should be handled
if len(function_calls_to_run) != len(function_call_results):
return final_response + self.response(messages=messages)
return final_response + self.response(messages=messages, current_user_query=current_user_query)

# Add results of the function calls to the messages
elif len(function_call_results) > 0:
@@ -234,14 +240,14 @@ def response(self, messages: List[Message]) -> str:
if any(item.tool_call_error for item in function_call_results):
messages = self.add_tool_call_error_message(messages)
else:
messages = self.add_original_user_message(messages)
messages = self.add_original_user_message(messages, current_user_query)

# Deactivate tool calls by turning off JSON mode after 1 tool call
if self.deactivate_tools_after_use:
self.deactivate_function_calls()

# -*- Yield new response using results of tool calls
final_response += self.response(messages=messages)
final_response += self.response(messages=messages, current_user_query=current_user_query)
return final_response

logger.debug("---------- Ollama Response End ----------")
@@ -258,6 +264,12 @@ def response_stream(self, messages: List[Message]) -> Iterator[str]:
for m in messages:
m.log()

original_user_message_content = None
for m in reversed(messages):
if m.role == "user":
original_user_message_content = m.content
break

assistant_message_content = ""
response_is_tool_call = False
tool_call_bracket_count = 0
@@ -407,7 +419,7 @@ def response_stream(self, messages: List[Message]) -> Iterator[str]:
messages = self.add_tool_call_error_message(messages)

# -*- Yield new response using results of tool calls
yield from self.response_stream(messages=messages)
yield from self.response_stream(messages=messages, current_user_query=current_user_query)

elif assistant_message.tool_calls is not None and self.run_tools:
function_calls_to_run: List[FunctionCall] = []
@@ -445,32 +457,26 @@ def response_stream(self, messages: List[Message]) -> Iterator[str]:
if any(item.tool_call_error for item in function_call_results):
messages = self.add_tool_call_error_message(messages)
else:
messages = self.add_original_user_message(messages)
messages = self.add_original_user_message(messages,original_user_message_content)

# Deactivate tool calls by turning off JSON mode after 1 tool call
if self.deactivate_tools_after_use:
self.deactivate_function_calls()

# -*- Yield new response using results of tool calls
yield from self.response_stream(messages=messages)
yield from self.response_stream(messages=messages, current_user_query=current_user_query)

logger.debug("---------- Ollama Response End ----------")

def add_original_user_message(self, messages: List[Message]) -> List[Message]:
def add_original_user_message(self, messages: List[Message], current_user_query: str | None = None) -> List[Message]:
# Add the original user message to the messages to remind the LLM of the original task
original_user_message_content = None
for m in messages:
if m.role == "user":
original_user_message_content = m.content
break

if original_user_message_content is not None:
if current_user_query is not None:
_content = (
"Using the results of the tools above, respond to the following message. "
"If the user explicitly requests raw data or specific formats like JSON, provide it as requested. "
"Otherwise, use the tool results to provide a clear and relevant answer without "
"returning the raw results directly:"
f"\n\n<user_message>\n{original_user_message_content}\n</user_message>"
f"\n\n<user_message>\n{current_user_query}\n</user_message>"
)

messages.append(Message(role="user", content=_content))
Expand Down

0 comments on commit 6d322fe

Please sign in to comment.