diff --git a/src/mistral_common/protocol/instruct/validator.py b/src/mistral_common/protocol/instruct/validator.py index ab3232d..c63339d 100644 --- a/src/mistral_common/protocol/instruct/validator.py +++ b/src/mistral_common/protocol/instruct/validator.py @@ -192,7 +192,7 @@ def _validate_message_order(self, messages: List[ChatMessage]) -> None: elif previous_role == Roles.user: expected_roles = {Roles.assistant, Roles.system, Roles.user} elif previous_role == Roles.assistant: - expected_roles = {Roles.user, Roles.tool} + expected_roles = {Roles.assistant, Roles.user, Roles.tool} elif previous_role == Roles.tool: expected_roles = {Roles.assistant, Roles.tool}