diff --git a/pyproject.toml b/pyproject.toml index 46da2e0..f311eeb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "mistral_common" -version = "1.4.4" +version = "1.5.0" description = "" authors = ["bam4d "] readme = "README.md" diff --git a/src/mistral_common/__init__.py b/src/mistral_common/__init__.py index c0f285b..5b60188 100644 --- a/src/mistral_common/__init__.py +++ b/src/mistral_common/__init__.py @@ -1 +1 @@ -__version__ = "1.4.4" +__version__ = "1.5.0" diff --git a/src/mistral_common/data/mistral_instruct_tokenizer_241114.model.v7m1 b/src/mistral_common/data/mistral_instruct_tokenizer_241114.model.v7m1 new file mode 100644 index 0000000..d23856f Binary files /dev/null and b/src/mistral_common/data/mistral_instruct_tokenizer_241114.model.v7m1 differ diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index a151855..2852f96 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -19,7 +19,7 @@ from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.tool_calls import FunctionCall, Tool, ToolCall from mistral_common.tokens.instruct.request import InstructRequest -from mistral_common.tokens.tokenizers.base import InstructRequestType +from mistral_common.tokens.tokenizers.base import InstructRequestType, TokenizerVersion class InstructRequestNormalizer( @@ -35,6 +35,8 @@ class InstructRequestNormalizer( - Normalize tool calls """ + system_prompt_in_begin: bool = False + def __init__( self, user_message_class: Type[UserMessageType], @@ -117,7 +119,7 @@ def _aggregate_assistant_messages(self, messages: List[UATS]) -> AssistantMessag weight: Optional[float] = None for message in messages: assert isinstance(message, self._assistant_message_class), "Expected assistant message" - if message.tool_calls is not None: + if message.tool_calls is not None and len(message.tool_calls) > 0: for tool_call in message.tool_calls: normalized_tool_call = self._normalize_tool_call(tool_call) tool_calls.append(normalized_tool_call) @@ -205,7 +207,9 @@ def _aggregate_messages(self, request: ChatCompletionRequest[UATS]) -> List[UATS # If the first message is not a user message, or we didnt aggregate # anything (all system messages) for example, add an empty user message - if len(aggregated_messages) == 0 or aggregated_messages[0].role != Roles.user: + if len(aggregated_messages) == 0 or ( + not self.system_prompt_in_begin and aggregated_messages[0].role != Roles.user + ): aggregated_messages.insert(0, self._user_message_class(content="")) return aggregated_messages @@ -217,3 +221,45 @@ def from_chat_completion_request(self, request: ChatCompletionRequest[UATS]) -> return self._instruct_request_class( messages=messages, system_prompt=system_prompt, available_tools=request.tools ) + + +class InstructRequestNormalizerV7(InstructRequestNormalizer): + system_prompt_in_begin: bool = True + + @staticmethod + def normalizer() -> "InstructRequestNormalizerV7": + return InstructRequestNormalizerV7( + UserMessage, + AssistantMessage, + ToolMessage, + SystemMessage, + InstructRequest[UATS, Tool], + ) + + def _aggregate_role(self, messages: List[UATS], role: Optional[Roles]) -> Sequence[UATS]: + if role == Roles.tool: + return self._aggregate_tool_messages(messages) + elif role == Roles.assistant: + return [self._aggregate_assistant_messages(messages)] + elif role == Roles.user: + return 
[self._aggregate_user_messages(messages)] + elif role == Roles.system: + return messages + else: + assert role is None and len(messages) == 0 + return [] + + def _aggregate_system_prompts(self, request: ChatCompletionRequest[UATS]) -> Optional[str]: + raise NotImplementedError("We should not aggregate system prompts") + + def from_chat_completion_request(self, request: ChatCompletionRequest[UATS]) -> InstructRequestType: # type: ignore[type-var] + messages = self._aggregate_messages(request) + return self._instruct_request_class(messages=messages, system_prompt=None, available_tools=request.tools) # type: ignore[no-any-return] + + +def normalizer_for_tokenizer_version(version: TokenizerVersion) -> InstructRequestNormalizer: + if version in {TokenizerVersion.v1, TokenizerVersion.v2, TokenizerVersion.v3}: + return InstructRequestNormalizer.normalizer() + elif version == TokenizerVersion.v7: + return InstructRequestNormalizerV7.normalizer() + raise ValueError(f"Unknown tokenizer version {version}") diff --git a/src/mistral_common/protocol/instruct/request.py b/src/mistral_common/protocol/instruct/request.py index 2704de5..4c4ee90 100644 --- a/src/mistral_common/protocol/instruct/request.py +++ b/src/mistral_common/protocol/instruct/request.py @@ -24,3 +24,4 @@ class ChatCompletionRequest(BaseCompletionRequest, Generic[ChatMessageType]): response_format: ResponseFormat = Field(default_factory=ResponseFormat) tools: Optional[List[Tool]] = None tool_choice: ToolChoice = ToolChoice.auto + truncate_for_context_length: bool = False diff --git a/src/mistral_common/tokens/instruct/request.py b/src/mistral_common/tokens/instruct/request.py index fec27ff..95ebdb3 100644 --- a/src/mistral_common/tokens/instruct/request.py +++ b/src/mistral_common/tokens/instruct/request.py @@ -22,3 +22,4 @@ class InstructRequest(MistralBase, Generic[ChatMessageType, ToolType]): messages: List[ChatMessageType] system_prompt: Optional[str] = None available_tools: Optional[List[ToolType]] = None + truncate_at_max_tokens: Optional[int] = None diff --git a/src/mistral_common/tokens/tokenizers/base.py b/src/mistral_common/tokens/tokenizers/base.py index 3853a4b..63b6040 100644 --- a/src/mistral_common/tokens/tokenizers/base.py +++ b/src/mistral_common/tokens/tokenizers/base.py @@ -34,12 +34,16 @@ class SpecialTokens(str, Enum): prefix = "[PREFIX]" middle = "[MIDDLE]" suffix = "[SUFFIX]" + begin_system = "[SYSTEM_PROMPT]" + end_system = "[/SYSTEM_PROMPT]" + begin_tool_content = "[TOOL_CONTENT]" class TokenizerVersion(str, Enum): v1 = "v1" # vocab_size = 32000 v2 = "v2" # vocab_size = 32768 with special control tokens [INST], [\INST] v3 = "v3" # vocab_size = 32768 (spm) OR 128000 (tekken) with improved function calling + v7 = "v7" # vocab_size = 32768 (spm) or 128000 (tekken) with improved system prompt and function calling class Tokenized(MistralBase): diff --git a/src/mistral_common/tokens/tokenizers/mistral.py b/src/mistral_common/tokens/tokenizers/mistral.py index 188a127..dc38ec5 100644 --- a/src/mistral_common/tokens/tokenizers/mistral.py +++ b/src/mistral_common/tokens/tokenizers/mistral.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from typing import Callable, Dict, Generic, List, Union +from typing import Callable, Dict, Generic, List, Optional, Union from mistral_common.exceptions import ( TokenizerException, @@ -13,7 +13,7 @@ ToolMessageType, UserMessageType, ) -from mistral_common.protocol.instruct.normalize import InstructRequestNormalizer +from 
mistral_common.protocol.instruct.normalize import InstructRequestNormalizer, normalizer_for_tokenizer_version from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.validator import ( MistralRequestValidator, @@ -39,13 +39,17 @@ InstructTokenizerV1, InstructTokenizerV2, InstructTokenizerV3, + InstructTokenizerV7, SentencePieceTokenizer, + get_mm_config, is_sentencepiece, ) from mistral_common.tokens.tokenizers.tekken import Tekkenizer, is_tekken -def load_mm_encoder(mm_config: MultimodalConfig, tokenizer: Tekkenizer) -> MultiModalEncoder: +def load_mm_encoder( + mm_config: MultimodalConfig, tokenizer: Union[Tekkenizer, SentencePieceTokenizer] +) -> MultiModalEncoder: special_ids = SpecialImageIDs( img=tokenizer.get_control_token(SpecialTokens.img.value), img_break=tokenizer.get_control_token(SpecialTokens.img_break.value), @@ -99,6 +103,13 @@ def v3(cls, is_tekken: bool = False, is_mm: bool = False) -> "MistralTokenizer": return cls.from_file(str(cls._data_path() / tokenizer_name), mode=ValidationMode.test) + @classmethod + def v7(cls) -> "MistralTokenizer": + """mistral-large 2.1""" + return cls.from_file( + str(cls._data_path() / "mistral_instruct_tokenizer_241114.model.v7m1"), mode=ValidationMode.test + ) + @classmethod def from_model(cls, model: str) -> "MistralTokenizer": model_name_to_tokenizer_cls: Dict[str, Callable[[], MistralTokenizer]] = { @@ -136,14 +147,15 @@ def from_file( if is_tekken(tokenizer_filename): tokenizer = Tekkenizer.from_file(tokenizer_filename) mm_config = tokenizer.multimodal - mm_encoder = load_mm_encoder(mm_config, tokenizer) if mm_config is not None else None elif is_sentencepiece(tokenizer_filename): tokenizer = SentencePieceTokenizer(tokenizer_filename) - mm_encoder = None + mm_config = get_mm_config(tokenizer_filename) else: raise TokenizerException(f"Unrecognized tokenizer file: {tokenizer_filename}") - request_normalizer = InstructRequestNormalizer.normalizer() + mm_encoder = load_mm_encoder(mm_config, tokenizer) if mm_config is not None else None + + request_normalizer = normalizer_for_tokenizer_version(tokenizer.version) if tokenizer.version == TokenizerVersion.v1: assert mm_encoder is None, "Tokenizer version needs to be >= v3" @@ -165,14 +177,35 @@ def from_file( validator=MistralRequestValidatorV3(mode=mode), request_normalizer=request_normalizer, ) + elif tokenizer.version == TokenizerVersion.v7: + return MistralTokenizer( + InstructTokenizerV7(tokenizer, mm_encoder=mm_encoder), + validator=MistralRequestValidatorV3(mode=mode), + request_normalizer=request_normalizer, + ) else: raise TokenizerException(f"Unrecognized tokenizer filename: {tokenizer_filename}") raise TokenizerException(f"Unrecognized tokenizer version: {tokenizer.version}") - def encode_chat_completion(self, request: ChatCompletionRequest[UATS]) -> TokenizedType: + def encode_chat_completion( + self, request: ChatCompletionRequest[UATS], max_model_input_len: Optional[int] = None + ) -> TokenizedType: validated_request = self._chat_completion_request_validator.validate_request(request) + + if max_model_input_len is None and request.truncate_for_context_length: + # the max_model_input_len arg should not be optionnal ; + # but this function is used in many small scripts that have no use + # for truncation, and don't provide the max model len + raise TokenizerException( + "encoding a chat completion request with truncation, but no max model len was provided", + ) + instruct_request = 
self._instruct_request_normalizer.from_chat_completion_request(validated_request) + + if request.truncate_for_context_length: + instruct_request.truncate_at_max_tokens = max_model_input_len + return self.instruct_tokenizer.encode_instruct(instruct_request) def encode_fim(self, request: FIMRequest) -> TokenizedType: diff --git a/src/mistral_common/tokens/tokenizers/multimodal.py b/src/mistral_common/tokens/tokenizers/multimodal.py index 0498213..c7542da 100644 --- a/src/mistral_common/tokens/tokenizers/multimodal.py +++ b/src/mistral_common/tokens/tokenizers/multimodal.py @@ -1,6 +1,7 @@ import base64 import logging from dataclasses import dataclass +from enum import Enum from io import BytesIO from typing import Tuple, Union @@ -57,6 +58,18 @@ def image_from_chunk(chunk: Union[ImageURLChunk, ImageChunk]) -> SerializableIma DATASET_STD = (0.26862954, 0.26130258, 0.27577711) # RGB +# only relevant for spm +class MultiModalVersion(str, Enum): + m1 = "m1" + + @property + def config(self) -> "MultimodalConfig": + if self.name == "m1": + return MultimodalConfig(16, 1024) + + raise NotImplementedError(f"{self.name}") + + @dataclass class MultimodalConfig: image_patch_size: int diff --git a/src/mistral_common/tokens/tokenizers/sentencepiece.py b/src/mistral_common/tokens/tokenizers/sentencepiece.py index 98a442b..6b92411 100644 --- a/src/mistral_common/tokens/tokenizers/sentencepiece.py +++ b/src/mistral_common/tokens/tokenizers/sentencepiece.py @@ -14,6 +14,7 @@ AssistantMessage, AssistantMessageType, ContentChunk, + SystemMessage, TextChunk, ToolMessage, UserMessage, @@ -30,19 +31,22 @@ Tokenizer, TokenizerVersion, ) -from mistral_common.tokens.tokenizers.multimodal import MultiModalEncoder +from mistral_common.tokens.tokenizers.multimodal import MultimodalConfig, MultiModalEncoder, MultiModalVersion def is_sentencepiece(path: Union[str, Path]) -> bool: if isinstance(path, str): path = Path(path) - suffixes = [f".model.{v}" for v in list(TokenizerVersion.__members__)] + [".model"] + instruct_versions = list(TokenizerVersion.__members__) + mm_versions = list(MultiModalVersion.__members__) + [""] # allow no mm version + suffixes = [f".model.{v}{m}" for v in instruct_versions for m in mm_versions] + [".model"] + return path.is_file() and any(path.name.endswith(suffix) for suffix in suffixes) def get_spm_version(tokenizer_filename: str, raise_deprecated: bool = False) -> TokenizerVersion: - _version_str = tokenizer_filename.split(".")[-1] + _version_str = tokenizer_filename.split(".")[-1].split("m")[0] if _version_str == "model": if raise_deprecated: raise TokenizerException(f"Make sure to rename your tokenizer file to end with {tokenizer_filename}.v1.") @@ -56,6 +60,19 @@ def get_spm_version(tokenizer_filename: str, raise_deprecated: bool = False) -> return TokenizerVersion(_version_str) +def get_mm_config(tokenizer_filename: str) -> Optional[MultimodalConfig]: + _version_str = tokenizer_filename.split(".")[-1] + if "m" not in _version_str: + return None + + _mm_version_str = "m" + _version_str.split("m")[-1] + + if _mm_version_str not in MultiModalVersion.__members__: + raise TokenizerException(f"Unrecognized tokenizer filename: {tokenizer_filename}") + + return MultiModalVersion(_mm_version_str).config + + class SentencePieceTokenizer(Tokenizer): def __init__(self, model_path: str, tokenizer_version: Optional[TokenizerVersion] = None) -> None: self._logger = logging.getLogger(self.__class__.__name__) @@ -173,14 +190,24 @@ def encode_tool_message(self, message: ToolMessage, 
is_before_last_user_message: def encode_assistant_message(self, message: AssistantMessageType, is_before_last_user_message: bool) -> List[int]: raise NotImplementedError("Assistant message not implemented") + def _truncate_for_max_tokens( + self, + tokenized: List[Optional[List[int]]], + messages: List[AssistantMessageType], + max_tokens: int, + last_user_message_index: int, + ) -> None: + # Tokenizer ⩽ V3 does not support truncation + return + def encode_instruct( self, request: InstructRequest[AssistantMessageType, Tool], ) -> Tokenized: # init at bos - tokens = self.start() images: List[np.ndarray] = [] prefix_ids: Optional[List[int]] = None + tokens_list: List[Optional[List[int]]] = [] # find last user message first_user_idx, last_user_idx = self.find_first_last_user(request) @@ -201,7 +228,23 @@ def encode_instruct( new_tokens = self.encode_assistant_message(msg, msg_idx < last_user_idx) if msg_idx == len(request.messages) - 1: prefix_ids = new_tokens - tokens.extend(new_tokens) + elif isinstance(msg, SystemMessage): + new_tokens = self.encode_system_message(msg) + + tokens_list.append(new_tokens) + + if request.truncate_at_max_tokens is not None: + self._truncate_for_max_tokens( + tokens_list, + request.messages, + request.truncate_at_max_tokens, + last_user_idx, + ) + tokens = self.start() + + for tok in tokens_list: + if tok is not None: + tokens.extend(tok) return Tokenized( tokens=tokens, @@ -363,26 +406,32 @@ def _prepare_function_call(self, tool_call: ToolCall) -> Dict[str, Any]: "arguments": self._parse_json_content(tool_call.function.arguments), } + def _encode_normal_content_assistant_message(self, message: AssistantMessageType) -> List[int]: + assert message.content, f"Assistant message must have content. Got {message}" + return self.tokenizer.encode(message.content.rstrip(" "), bos=False, eos=False) + + def _encode_tool_calls_in_assistant_message(self, message: AssistantMessageType) -> List[int]: + assert message.tool_calls, f"Assistant message must have tool calls. 
Got {message}" + prepared_tool_calls = [] + for tool_call in message.tool_calls: + prepared_tool_calls.append(self._prepare_function_call(tool_call)) + tool_call_str = json.dumps(prepared_tool_calls, ensure_ascii=False) + curr_tokens = [ + self.TOOL_CALLS, + *self.tokenizer.encode(tool_call_str, bos=False, eos=False), + ] + return curr_tokens + def encode_assistant_message(self, message: AssistantMessageType, is_before_last_user_message: bool) -> List[int]: - if message.tool_calls is not None and len(message.tool_calls) > 0: + if message.tool_calls: if is_before_last_user_message: # don't tokenize tool call before last user message return [] - - prepared_tool_calls = [] - for tool_call in message.tool_calls: - prepared_tool_calls.append(self._prepare_function_call(tool_call)) - - tool_call_str = json.dumps(prepared_tool_calls, ensure_ascii=False) - curr_tokens = [ - self.TOOL_CALLS, - *self.tokenizer.encode(tool_call_str, bos=False, eos=False), - ] + curr_tokens = self._encode_tool_calls_in_assistant_message(message) elif message.content: - curr_tokens = self.tokenizer.encode(message.content, bos=False, eos=False) + curr_tokens = self._encode_normal_content_assistant_message(message) else: raise TokenizerException(f"Invalid assistant message: {message.content}") - if not message.prefix: curr_tokens.append(self.tokenizer.eos_id) return curr_tokens @@ -496,3 +545,128 @@ def encode_user_content( images.append(img_encoding.image) return tokens, images + + +class InstructTokenizerV7(InstructTokenizerV3): + """ + The difference with V3 tokenizer is that it encodes the system prompts differently: + - in V7 the system prompts are treated as separate SystemMessages + - they are no longer prepended to the last user message + - they are printed between special tokens + Tool call results are encoded as : + - [begin tool call] call_id_tokens [tool_content] content tokens [end tool call] + """ + + def __init__(self, tokenizer: Tokenizer, mm_encoder: Optional[MultiModalEncoder] = None) -> None: + super().__init__(tokenizer, mm_encoder) + self.BEGIN_SYSTEM = self.tokenizer.get_control_token(SpecialTokens.begin_system.value) + self.END_SYSTEM = self.tokenizer.get_control_token(SpecialTokens.end_system.value) + self.BEGIN_TOOL_CONTENT = self.tokenizer.get_control_token(SpecialTokens.begin_tool_content.value) + + def _truncate_for_max_tokens( + self, + tokenized_messages: List[Optional[List[int]]], + messages: List[AssistantMessageType], + max_tokens: int, + last_user_message_index: int, + ) -> None: + # drop some messages to fit in max_tokens. 
Rules: + # - don't drop any system messages + # - when a user message is dropped, all following assistant|tool message should be dropped until the next + # user message + # - we never drop the last message + to_drop = sum(len(t) for t in tokenized_messages if t is not None) - max_tokens + + def drop(idx: int) -> None: + nonlocal to_drop + if isinstance(messages[idx], SystemMessage): + # never drop system messages + return + if idx == last_user_message_index: + # never drop the last user message + return + tok = tokenized_messages[idx] + assert tok is not None + to_drop -= len(tok) + tokenized_messages[idx] = None + + current_idx = 0 + while to_drop > 0 and current_idx < len(messages): + drop(current_idx) + current_idx += 1 + if isinstance(messages[current_idx - 1], UserMessage): + # if we just dropped a UserMessage, + # also drop everything until the next user message + while current_idx < len(messages) and not isinstance(messages[current_idx], UserMessage): + drop(current_idx) + current_idx += 1 + + if to_drop > 0: + raise TokenizerException("Input couldn't fit in truncate_at_max_token") + + def encode_system_message(self, message: SystemMessage) -> List[int]: + assert message.content is not None + assert isinstance(message.content, str), "Message content must be normalized" + tokens = [ + self.BEGIN_SYSTEM, + *self.tokenizer.encode(message.content, bos=False, eos=False), + self.END_SYSTEM, + ] + return tokens + + def encode_user_message( + self, + message: UserMessage, + available_tools: Optional[List[Tool]], + is_last: bool, + is_first: bool, + system_prompt: Optional[str] = None, + force_img_first: bool = False, + ) -> Tuple[List[int], List[np.ndarray]]: + assert system_prompt is None, "in Tokenizer V7 we don't encode system prompts in user messages" + return super().encode_user_message( + message, + available_tools, + is_last=is_last, + is_first=is_first, + system_prompt=None, + force_img_first=force_img_first, + ) + + def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> List[int]: + """ + Same as V3 but tools not wrapped in a list and history is tokenized also + """ + assert message.tool_call_id is not None + tool_call_id_tokens = self.tokenizer.encode(message.tool_call_id, bos=False, eos=False) + tokens = self.tokenizer.encode(message.content, bos=False, eos=False) + + prefix_tokens = [ + self.BEGIN_TOOL_RESULTS, + *tool_call_id_tokens, + self.BEGIN_TOOL_CONTENT, + ] + curr_tokens = [ + *prefix_tokens, + *tokens, + self.END_TOOL_RESULTS, + ] + return curr_tokens + + def encode_assistant_message(self, message: AssistantMessageType, is_before_last_user_message: bool) -> List[int]: + if not message.content and not message.tool_calls: + raise TokenizerException(f"Invalid assistant message: {message}") + curr_tokens: list = [] + if message.content: + if isinstance(message.content, str): + curr_tokens += self._encode_normal_content_assistant_message(message) + elif isinstance(message.content, list): + curr_tokens += self.encode_content_chunks( + message.content, is_last=False, system_prompt=None, force_img_first=True + ).tokens + if message.tool_calls: + curr_tokens += self._encode_tool_calls_in_assistant_message(message) + if not message.prefix: + curr_tokens.append(self.tokenizer.eos_id) + + return curr_tokens diff --git a/src/mistral_common/tokens/tokenizers/tekken.py b/src/mistral_common/tokens/tokenizers/tekken.py index d2232ee..da6859c 100644 --- a/src/mistral_common/tokens/tokenizers/tekken.py +++ b/src/mistral_common/tokens/tokenizers/tekken.py 
@@ -75,6 +75,9 @@ class Tekkenizer(Tokenizer): SpecialTokens.prefix, SpecialTokens.middle, SpecialTokens.suffix, + SpecialTokens.begin_system, + SpecialTokens.end_system, + SpecialTokens.begin_tool_content, ) SPECIAL_TOKEN_TEMPLATE = "" diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 9637a18..6bbe311 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -12,7 +12,7 @@ ToolMessage, UserMessage, ) -from mistral_common.protocol.instruct.normalize import InstructRequestNormalizer +from mistral_common.protocol.instruct.normalize import InstructRequestNormalizer, InstructRequestNormalizerV7 from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.instruct.tool_calls import Function, FunctionCall, Tool, ToolCall from mistral_common.tokens.instruct.request import InstructRequest @@ -23,6 +23,10 @@ class TestChatCompletionRequestNormalization: def normalizer(self) -> InstructRequestNormalizer: return InstructRequestNormalizer(UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest) + @pytest.fixture(autouse=True) + def normalizer_v7(self) -> InstructRequestNormalizerV7: + return InstructRequestNormalizerV7(UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest) + def mock_chat_completion(self, messages: List[ChatMessage]) -> ChatCompletionRequest: return ChatCompletionRequest( model="test", @@ -77,13 +81,32 @@ def test_system_assistant_user(self, normalizer: InstructRequestNormalizer) -> N parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - assert parsed_request.system_prompt == "S" - first_message = parsed_request.messages[0] assert isinstance(first_message, UserMessage) assert first_message.content == "" assert parsed_request.system_prompt == "S" + def test_system_assistant_user_v7(self, normalizer_v7: InstructRequestNormalizerV7) -> None: + chat_completion_request = self.mock_chat_completion( + messages=[ + SystemMessage(content="S"), + AssistantMessage(content="A"), + UserMessage(content="U"), + ] + ) + + parsed_request: InstructRequest = normalizer_v7.from_chat_completion_request(chat_completion_request) + + first_message = parsed_request.messages[0] + assert isinstance(first_message, SystemMessage) + assert first_message.content == "S" + + second_message = parsed_request.messages[1] + assert isinstance(second_message, AssistantMessage) + assert second_message.content == "A" + + assert parsed_request.system_prompt is None + def test_assistant_system_user_adds_user(self, normalizer: InstructRequestNormalizer) -> None: chat_completion_request = self.mock_chat_completion( messages=[ @@ -104,6 +127,28 @@ def test_assistant_system_user_adds_user(self, normalizer: InstructRequestNormal assert first_message.content == "" assert parsed_request.system_prompt == "S" + def test_assistant_assistant_system_v7(self, normalizer_v7: InstructRequestNormalizer) -> None: + chat_completion_request = self.mock_chat_completion( + messages=[ + AssistantMessage(content="A"), + SystemMessage(content="S"), + ] + ) + + parsed_request = normalizer_v7.from_chat_completion_request(chat_completion_request) + + assert parsed_request.system_prompt is None + + assert len(parsed_request.messages) == 2 + + first_message = parsed_request.messages[0] + assert isinstance(first_message, AssistantMessage) + assert first_message.content == "A" + + second_message = parsed_request.messages[1] + assert isinstance(second_message, SystemMessage) + assert 
second_message.content == "S" + def check_merge( self, roles: List[str], diff --git a/tests/test_tokenize_v3.py b/tests/test_tokenize_v3.py index eb16d37..d3c6c14 100644 --- a/tests/test_tokenize_v3.py +++ b/tests/test_tokenize_v3.py @@ -36,7 +36,7 @@ def tekken_tokenizer() -> InstructTokenizer: def test_is_spm() -> None: # this is valid - for suffix in list(TokenizerVersion.__members__): + for suffix in list(TokenizerVersion.__members__) + ["v3m1"]: with NamedTemporaryFile(suffix=".model." + suffix) as f: assert is_sentencepiece(f.name) @@ -49,13 +49,12 @@ def test_is_spm() -> None: def test_spm_version() -> None: - directory = Path(__file__).parent / "data" + directory = Path(__file__).parent.parent / "src" / "mistral_common" / "data" for file in directory.iterdir(): if not file.is_file() or str(file).endswith(".json"): continue - suffix = file.suffix[1:] - print(suffix) + suffix = file.suffix[1:].split("m")[0] if suffix == "model": assert SentencePieceTokenizer(str(file)).version == TokenizerVersion.v1 else: diff --git a/tests/test_tokenizer_v7.py b/tests/test_tokenizer_v7.py new file mode 100644 index 0000000..e10350e --- /dev/null +++ b/tests/test_tokenizer_v7.py @@ -0,0 +1,304 @@ +import json +from typing import List + +import pytest +from mistral_common.exceptions import TokenizerException +from mistral_common.protocol.instruct.messages import ( + AssistantMessage, + ChatMessage, + ImageChunk, + SystemMessage, + TextChunk, + ToolMessage, + UserMessage, +) +from mistral_common.protocol.instruct.request import ChatCompletionRequest +from mistral_common.protocol.instruct.tool_calls import Function, FunctionCall, Tool, ToolCall +from mistral_common.tokens.tokenizers.base import InstructRequest, TokenizerVersion +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from mistral_common.tokens.tokenizers.multimodal import ImageEncoder +from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerV7 +from mistral_common.tokens.tokenizers.tekken import Tekkenizer +from PIL import Image + +from tests.test_tekken import _quick_vocab + + +@pytest.fixture(scope="session") +def tekkenizer() -> InstructTokenizerV7: + tokenizer = Tekkenizer( + _quick_vocab([b"a", b"b", b"c", b"f", b"de"]), + pattern=r".+", # single token, whole string + vocab_size=256 + 100, + num_special_tokens=100, + version=TokenizerVersion.v7, + ) + return InstructTokenizerV7(tokenizer) + + +@pytest.fixture(scope="session") +def spm_tokenizer() -> InstructTokenizerV7: + tokenizer = MistralTokenizer.v7().instruct_tokenizer + mm_encoder = tokenizer.mm_encoder + assert isinstance(mm_encoder, ImageEncoder) + # hardcode image_patch_size = 2 for easier checks + mm_encoder.mm_config.image_patch_size = 2 + return tokenizer # type: ignore + + +def test_tokenize_assistant_message(spm_tokenizer: InstructTokenizerV7) -> None: + tokenized = spm_tokenizer.encode_instruct( + InstructRequest( + messages=[ + UserMessage( + content=[ + TextChunk( + text="a", + ), + ImageChunk(image=Image.new("RGB", (4, 4), "red")), + ] + ), + AssistantMessage(content="b"), + ToolMessage(tool_call_id="b", content="f"), + ], + ) + ) + _im = 10 + _im_break = 14 + _im_end = 15 + img_tokens = [_im, _im, _im_break, _im, _im, _im_end] + assert tokenized.tokens == [ + 1, # bos + 3, # begin_inst + *img_tokens, + 1032, # a + 4, # end_inst + 1055, # b + 2, # eos + 8, # [TOOL_RESULTS] + 1055, # tool_call_id b + 18, # [TOOL_CONTENT] + 1053, # f + 9, # [/TOOL_RESULTS] + ] + assert ( + tokenized.text + == 
"[INST][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]▁a[/INST]▁b[TOOL_RESULTS]▁b[TOOL_CONTENT]▁f[/TOOL_RESULTS]" # noqa + ) + + +@pytest.mark.parametrize( + "messages, expected_text", + [ + ( + [ + SystemMessage(content="a"), + UserMessage(content="a"), + AssistantMessage( + content="b", + tool_calls=[ + ToolCall( + function=FunctionCall( + name="t", + arguments=json.dumps( + { + "g": "h", + }, + ensure_ascii=False, + ), + ), + ), + ], + ), + ], + '[SYSTEM_PROMPT]▁a[/SYSTEM_PROMPT][AVAILABLE_TOOLS]▁[{"type":▁"function",▁"function":▁{"name":▁"t",▁"description":▁"",▁"parameters":▁{"type":▁"object",▁"properties":▁{"g":▁{"type":▁"string"},▁"h":▁{"type":▁"string"}}}}}][/AVAILABLE_TOOLS][INST]▁a[/INST]▁b[TOOL_CALLS]▁[{"name":▁"t",▁"arguments":▁{"g":▁"h"}}]', # noqa + ), + ( + [ + SystemMessage(content="a"), + UserMessage(content="a"), + UserMessage(content="c"), + AssistantMessage( + content="b", + tool_calls=[ + ToolCall( + function=FunctionCall( + name="t", + arguments=json.dumps( + { + "g": "h", + }, + ensure_ascii=False, + ), + ), + ), + ], + ), + ToolMessage(content="b", tool_call_id="1234"), + ], + '[SYSTEM_PROMPT]▁a[/SYSTEM_PROMPT][INST]▁a[/INST][AVAILABLE_TOOLS]▁[{"type":▁"function",▁"function":▁{"name":▁"t",▁"description":▁"",▁"parameters":▁{"type":▁"object",▁"properties":▁{"g":▁{"type":▁"string"},▁"h":▁{"type":▁"string"}}}}}][/AVAILABLE_TOOLS][INST]▁c[/INST]▁b[TOOL_CALLS]▁[{"name":▁"t",▁"arguments":▁{"g":▁"h"}}][TOOL_RESULTS]▁1234[TOOL_CONTENT]▁b[/TOOL_RESULTS]', # noqa + ), + ], +) +def test_encode_spm(spm_tokenizer: InstructTokenizerV7, messages: List[ChatMessage], expected_text: str) -> None: + tokenized = spm_tokenizer.encode_instruct( + InstructRequest( + available_tools=[ + Tool( + function=Function( + name="t", + parameters={ + "type": "object", + "properties": { + "g": {"type": "string"}, + "h": {"type": "string"}, + }, + }, + ) + ), + ], + messages=messages, + ) + ) + + assert tokenized.text == expected_text, f"{tokenized.text} != {expected_text}" + + +def test_encode_chat_completion() -> None: + tokenizer = MistralTokenizer.v7() + + request: ChatCompletionRequest = ChatCompletionRequest( + tools=[ + Tool( + function=Function( + name="t", + parameters={ + "type": "object", + "properties": { + "g": {"type": "string"}, + "h": {"type": "string"}, + }, + }, + ) + ), + ], + messages=[ + SystemMessage(content="a"), + UserMessage( + content=[ + TextChunk( + text="a", + ), + ImageChunk(image=Image.new("RGB", (4, 4), "red")), + ] + ), + AssistantMessage(content="b"), + ToolMessage(tool_call_id="123456789", content="f"), + ], + ) + + encoded = tokenizer.encode_chat_completion(request) + + assert len(encoded.images) == 1 + assert encoded.images[0].shape == (3, 16, 16) + assert ( + encoded.text + == '[SYSTEM_PROMPT]▁a[/SYSTEM_PROMPT][AVAILABLE_TOOLS]▁[{"type":▁"function",▁"function":▁{"name":▁"t",▁"description":▁"",▁"parameters":▁{"type":▁"object",▁"properties":▁{"g":▁{"type":▁"string"},▁"h":▁{"type":▁"string"}}}}}][/AVAILABLE_TOOLS][INST][IMG][IMG_END]▁a[/INST]▁b[TOOL_RESULTS]▁123456789[TOOL_CONTENT]▁f[/TOOL_RESULTS]' # noqa + ) + + +@pytest.mark.parametrize( + "messages,truncated_text", + [ + # max_tokens is always set to truncate at 15 tokens + pytest.param( + # with the system prompts, only one user message fits, keep the last one + [ + SystemMessage(content="a"), + UserMessage(content="c"), + UserMessage(content="c"), + SystemMessage(content="a"), + UserMessage(content="bbbbbbb"), + ], + "[SYSTEM_PROMPT]a[/SYSTEM_PROMPT][SYSTEM_PROMPT]a[/SYSTEM_PROMPT][INST]bbbbbbb[/INST]", + 
id="keep_sys_and_last_message", + ), + pytest.param( + # drop the first assistant message - everything else fits + [ + AssistantMessage(content="c"), + UserMessage(content="b"), + UserMessage(content="a"), + UserMessage(content="aaaaaaa"), + ], + "[INST]b[/INST][INST]a[/INST][INST]aaaaaaa[/INST]", + ), + pytest.param( + # the result can start with a non-user message because the input did too + [ + AssistantMessage(content="c"), + AssistantMessage(content="b"), + UserMessage(content="a"), + UserMessage(content="aaaaaaa"), + ], + "b[INST]a[/INST][INST]aaaaaaa[/INST]", + ), + pytest.param( + # drop the first assistant message, then drop user+tool because the go together and both don't fit + [ + AssistantMessage(content="c"), + UserMessage(content="c"), + ToolMessage(content="c", tool_call_id="1234"), + UserMessage(content="a"), + AssistantMessage(content="bbbbbbb"), + ], + "[INST]a[/INST]bbbbbbb", + id="drop_by_chunk_1", + ), + pytest.param( + # drop everything but the last message, because the first chunk (3 messages) is too big + [ + UserMessage(content="c"), + AssistantMessage(content="c"), + AssistantMessage(content="c"), + UserMessage(content="aaaaaaa"), + ], + "[INST]aaaaaaa[/INST]", + id="drop_by_chunk_2", + ), + pytest.param( + [ + SystemMessage(content="a"), + UserMessage(content="c"), + AssistantMessage(content="c"), + UserMessage(content="a"), + AssistantMessage(content="a"), + SystemMessage(content="b"), + UserMessage(content="a"), + ], + "[SYSTEM_PROMPT]a[/SYSTEM_PROMPT][INST]a[/INST]a[SYSTEM_PROMPT]b[/SYSTEM_PROMPT][INST]a[/INST]", + id="full_convo", + ), + ], +) +def test_truncation(tekkenizer: InstructTokenizerV7, messages: List[ChatMessage], truncated_text: str) -> None: + tokenized = tekkenizer.encode_instruct(InstructRequest(messages=messages, truncate_at_max_tokens=15)) + assert tokenized.text == truncated_text, f"{tokenized.text} != {truncated_text}" + + +@pytest.mark.parametrize( + "messages", + [ + [ + # system prompt doesn't fit + SystemMessage(content="a" * 10), + ], + [ + # last user msg doesn't fit + UserMessage(content="a" * 10), + ], + ], +) +def test_truncation_failed(tekkenizer: InstructTokenizerV7, messages: List[ChatMessage]) -> None: + with pytest.raises(TokenizerException): + tekkenizer.encode_instruct(InstructRequest(messages=messages, truncate_at_max_tokens=9))