diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 1f8c6b035..3d816ede5 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -350,7 +350,8 @@ def _get_events_for_messages(self, messages: List[dict]): # For the rest of the messages, we transform them directly into events. # TODO: Move this to separate function once more types of messages are supported. - for msg in messages[p:]: + for idx in range(p, len(messages)): + msg = messages[idx] if msg["role"] == "user": events.append( { @@ -358,6 +359,16 @@ def _get_events_for_messages(self, messages: List[dict]): "final_transcript": msg["content"], } ) + + # If it's not the last message, we also need to add the `UserMessage` event + if idx != len(messages) - 1: + events.append( + { + "type": "UserMessage", + "text": msg["content"], + } + ) + elif msg["role"] == "assistant": events.append( {"type": "StartUtteranceBotAction", "script": msg["content"]} diff --git a/tests/test_bug_5.py b/tests/test_bug_5.py new file mode 100644 index 000000000..93eddc9f6 --- /dev/null +++ b/tests/test_bug_5.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemoguardrails import RailsConfig +from tests.utils import TestChat + +config = RailsConfig.from_content( + """ +define user express greeting + "hello" + "hi" + "how are you" + +define bot express greeting + "Hey!" + +define flow greeting + user express greeting + bot express greeting +""", + yaml_content=""" + models: + - type: main + engine: nemollm + model: gpt-43b-002 + """, +) + + +def test_1(): + chat = TestChat( + config, + llm_completions=[ + " express greeting", + ], + ) + + rails = chat.app + + messages = [ + {"role": "user", "content": "Hi 1!"}, + {"role": "assistant", "content": "Hi 2!"}, + {"role": "user", "content": "Hi 3!"}, + {"role": "assistant", "content": "Hi 4!"}, + {"role": "user", "content": "Hi!"}, + ] + new_message = rails.generate(messages=messages) + + assert new_message == {"role": "assistant", "content": "Hey!"} + + info = rails.explain() + assert len(info.llm_calls) == 1 + assert "Hi 1!" in info.llm_calls[0].prompt + assert "Hi 3!" in info.llm_calls[0].prompt