Add Skip Tokens Node (#264)
* [WIP] Add Skip Tokens Node

* Add Skip Tokens Node

* Ruff format

* Fix re2 segfault

* Fix re2 pattern issue

* Fix merges is bytes

* Fixes for transformers 4.45
apaniukov authored Sep 27, 2024
1 parent 81c067c commit 0ac89d3
Showing 27 changed files with 17,369 additions and 17,709 deletions.
262 changes: 131 additions & 131 deletions README.md

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions benchmark/benchmark.py
@@ -196,7 +196,10 @@ def main(
    if per_layer_stats:
        config[properties.enable_profiling()] = True

    start_compile = perf_counter()
    ov_tokenizer = compile_model(convert_tokenizer(hf_tokenizer), "CPU", config)
    end_compile = perf_counter()
    print(f"Time to compile tokenizer model: {end_compile - start_compile}s")

    dataset = sample_texts(dataset, batch * num_pairs)
    result_df = benchmark_tokenizers(ov_tokenizer, hf_tokenizer, dataset, per_layer_stats, batch)
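For reference, a minimal sketch of the timing pattern the benchmark now uses, assuming openvino's compile_model helper and an arbitrary Hugging Face checkpoint; conversion and compilation are timed together, as in the diff above:

from time import perf_counter

from openvino import compile_model
from openvino_tokenizers import convert_tokenizer
from transformers import AutoTokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # any checkpoint works here

start_compile = perf_counter()
ov_tokenizer = compile_model(convert_tokenizer(hf_tokenizer), "CPU")
end_compile = perf_counter()
print(f"Time to compile tokenizer model: {end_compile - start_compile}s")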
3 changes: 2 additions & 1 deletion python/openvino_tokenizers/cli.py
@@ -10,6 +10,7 @@
from openvino_tokenizers import convert_tokenizer
from openvino_tokenizers.constants import UTF8ReplaceMode


class StringToTypeAction(Action):
    string_to_type_dict = {
        "i32": Type.i32,
@@ -223,7 +224,7 @@ def get_parser() -> ArgumentParser:
        help=(
            "If specified, resulting strings during decoding are checked to be valid UTF-8 byte sequences. "
            f"If mode is '{UTF8ReplaceMode.REPLACE}' then invalid characters are replaced with �; if mode is '{UTF8ReplaceMode.IGNORE}' then invalid characters are skipped."
        )
        ),
    )
    return parser

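StringToTypeAction, touched above, follows the standard argparse custom-action pattern of mapping a string choice to an OpenVINO element type. A minimal sketch; the "i64" entry, the __call__ body, and the flag name are assumptions, only "i32" is visible in the diff:

from argparse import Action, ArgumentParser

from openvino.runtime import Type


class StringToTypeAction(Action):
    string_to_type_dict = {
        "i32": Type.i32,
        "i64": Type.i64,  # assumed second entry; the diff only shows "i32"
    }

    def __call__(self, parser, namespace, values, option_string=None):
        # replace the raw string with the corresponding OpenVINO element type
        setattr(namespace, self.dest, self.string_to_type_dict[values])


parser = ArgumentParser()
parser.add_argument("--output-type", action=StringToTypeAction, default=Type.i64)  # hypothetical flag name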
9 changes: 5 additions & 4 deletions python/openvino_tokenizers/constants.py
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from enum import Enum


ATTENTION_MASK_INPUT_NAME = "attention_mask"
TOKEN_IDS_INPUT_NAME = "input_ids"
@@ -33,10 +35,9 @@
MIN_CACHE_CAPACITY = 20_000
VOCAB_SIZE_CACHE_PROPORTION = 0.2

from enum import Enum
class UTF8ReplaceMode(Enum):
    IGNORE: str = 'ignore'
    REPLACE: str = 'replace'
    IGNORE: str = "ignore"
    REPLACE: str = "replace"

    def __str__(self):
        return self.value
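Because __str__ returns the member value, UTF8ReplaceMode interpolates cleanly into f-strings such as the CLI help text above; a quick usage sketch:

from openvino_tokenizers.constants import UTF8ReplaceMode

mode = UTF8ReplaceMode.REPLACE
print(f"mode is '{mode}'")  # -> mode is 'replace'
assert str(UTF8ReplaceMode.IGNORE) == "ignore"
assert UTF8ReplaceMode("replace") is UTF8ReplaceMode.REPLACE  # lookup by value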
3 changes: 2 additions & 1 deletion python/openvino_tokenizers/convert_tokenizer.py
@@ -9,8 +9,9 @@
from openvino.runtime import Model, Type
from openvino.runtime.exceptions import OVTypeError

from openvino_tokenizers.utils import change_inputs_type, change_outputs_type, update_rt_info
from openvino_tokenizers.constants import UTF8ReplaceMode
from openvino_tokenizers.utils import change_inputs_type, change_outputs_type, update_rt_info


logger = logging.getLogger(__name__)

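A hedged end-to-end sketch of how the new option is meant to be used from Python; it assumes convert_tokenizer accepts with_detokenizer and forwards utf8_replace_mode into the detokenizer pipeline, as the parser changes below suggest:

from openvino import compile_model
from openvino_tokenizers import convert_tokenizer
from openvino_tokenizers.constants import UTF8ReplaceMode
from transformers import AutoTokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")  # arbitrary example checkpoint

ov_tokenizer, ov_detokenizer = convert_tokenizer(
    hf_tokenizer,
    with_detokenizer=True,
    utf8_replace_mode=UTF8ReplaceMode.REPLACE,  # invalid byte sequences decode to "�"
)
compiled_detokenizer = compile_model(ov_detokenizer, "CPU")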
29 changes: 16 additions & 13 deletions python/openvino_tokenizers/hf_parser.py
@@ -30,7 +30,7 @@
    TOKEN_IDS_INPUT_NAME,
    TOKEN_TYPE_IDS_INPUT_NAME,
    TOKENIZER_NAME,
    UTF8ReplaceMode
    UTF8ReplaceMode,
)
from .tokenizer_pipeline import (
    AddToken,
@@ -43,7 +43,6 @@
    CombineSegmentsStep,
    DecodingStep,
    FuseStep,
    UTF8ValidateStep,
    NMTNormalizationStep,
    NormalizationStep,
    NormalizeUnicode,
@@ -53,9 +52,11 @@
    RegexNormalizationStep,
    RegexSplitStep,
    Sequence,
    SpecialTokensSplit,
    StripStringStep,
    TokenizerPipeline,
    TruncationStep,
    UTF8ValidateStep,
    VocabDecoderStep,
    WhitespaceSplitStep,
    WordPieceTokenizationStep,
@@ -120,9 +121,6 @@ def parse_byte_level_pretokenization_step(

    # regex is used by default, but it does not appear in config yet
    if pretokenizer_dict.get("use_regex", True):
        # re2 does not support negative lookahead, so there are two steps to replicate the behaviour
        # this WA causes segfault for CLIP tokenizer
        # steps.append(RegexSplitStep.add_whitespace_to_the_next_word())
        steps.append(RegexSplitStep.byte_level_splitter())

    steps.append(BytesToCharsStep())
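For context on BytesToCharsStep: byte-level BPE tokenizers first remap every byte to a printable unicode character (the GPT-2 scheme) so that splitting and merging can operate on text. A self-contained sketch of that table, assuming the standard GPT-2 construction rather than the repo's exact helper:

def bytes_to_unicode() -> dict:
    # printable byte values keep their own code point
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(256):
        if b not in bs:
            # remaining bytes are shifted into the 256+ range to stay printable
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return dict(zip(bs, (chr(c) for c in cs)))


table = bytes_to_unicode()
print("".join(table[b] for b in " hello".encode("utf-8")))  # -> Ġhello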
@@ -174,11 +172,11 @@ def parse(
        clean_up_tokenization_spaces: Optional[bool] = None,
        use_max_padding: bool = False,
        utf8_replace_mode: Optional[UTF8ReplaceMode] = None,

    ) -> TokenizerPipeline:
        self.number_of_inputs = self.number_of_inputs if number_of_inputs is None else number_of_inputs
        self.pipeline.number_of_inputs = self.number_of_inputs
        for add_steps in [
            self.special_tokens_split,
            self.normalization,
            self.pre_tokenization,
            self.tokenization_model,
@@ -194,6 +192,9 @@ def parse(

        return self.pipeline

    def special_tokens_split(self) -> None:
        self.pipeline.add_steps(SpecialTokensSplit.from_hf_tokenizer(self.original_tokenizer))
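An illustration of what the new special_tokens_split stage has to achieve: special tokens are isolated before normalization and pre-tokenization so they pass through unchanged. The regex construction below is illustrative only, not the repo's actual SpecialTokensSplit implementation:

import re

special_tokens = ["<|endoftext|>", "<sop>"]  # hypothetical set taken from a HF tokenizer
pattern = re.compile("(" + "|".join(map(re.escape, special_tokens)) + ")")

text = "<sop>Hello world<|endoftext|>"
chunks = [chunk for chunk in pattern.split(text) if chunk]
print(chunks)  # -> ['<sop>', 'Hello world', '<|endoftext|>']
# only the non-special chunks would then go through normalization and splitting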

    normalizers_map: Dict[
        str,
        Callable[[Dict[str, Any]], Union[NormalizationStep, List[NormalizationStep]]],
@@ -412,19 +413,20 @@ def decoding(
            self.pipeline.add_steps(CharsToBytesStep())
        else:
            self.pipeline.add_steps(FuseStep())

        if utf8_replace_mode is not None:
            self.pipeline.add_steps(UTF8ValidateStep(mode=utf8_replace_mode))

        if clean_up_tokenization_spaces is None:
            clean_up_tokenization_spaces = self.original_tokenizer.clean_up_tokenization_spaces

        if suffix := self.tokenizer_json["model"].get("end_of_word_suffix"):
            self.pipeline.add_steps(RegexDecodingStep.replace_end_of_word_suffix(suffix=suffix))
            self.pipeline.add_steps(RegexDecodingStep.rstrip_space())

        if prefix := self.tokenizer_json["model"].get("continuing_subword_prefix"):
            self.pipeline.add_steps(RegexDecodingStep.replace_continuing_subword_prefix(prefix=prefix))

        if clean_up_tokenization_spaces is None:
            clean_up_tokenization_spaces = self.original_tokenizer.clean_up_tokenization_spaces

        if clean_up_tokenization_spaces and self.pipeline.decoding_steps:
            self.pipeline.add_steps(RegexDecodingStep.clean_up_tokenization_spaces())
        return
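For reference, clean_up_tokenization_spaces mirrors the Hugging Face post-processing that removes the spaces a decoder leaves before punctuation and contractions; a rough pure-Python illustration of that behaviour:

def clean_up_tokenization(text: str) -> str:
    # the replacements transformers applies when clean_up_tokenization_spaces=True
    for before, after in [
        (" .", "."), (" ?", "?"), (" !", "!"), (" ,", ","),
        (" ' ", "'"), (" n't", "n't"), (" 'm", "'m"),
        (" 's", "'s"), (" 've", "'ve"), (" 're", "'re"),
    ]:
        text = text.replace(before, after)
    return text


print(clean_up_tokenization("Do n't worry , it 's fine ."))  # -> Don't worry, it's fine.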
Expand Down Expand Up @@ -1036,7 +1038,7 @@ def get_sp_detokenizer(
    if clean_up_tokenization_spaces:
        detokenizer = RegexDecodingStep.clean_up_tokenization_spaces().get_ov_subgraph(detokenizer)

    if utf8_replace_mode is not None:
    if utf8_replace_mode is not None:
        replace_mode = True if utf8_replace_mode is UTF8ReplaceMode.REPLACE else False
        UTF8ValidateStep(mode=replace_mode).get_ov_subgraph(detokenizer)

@@ -1085,6 +1087,7 @@ def convert_tiktoken_model_tokenizer(
    reference_vocab = getattr(hf_tokenizer, "get_vocab", lambda: None)()
    pipeline.add_steps(
        [
            SpecialTokensSplit.from_hf_tokenizer(hf_tokenizer),
            NormalizeUnicode("NFC"),
            RegexSplitStep(split_pattern, behaviour="contiguous"),
            BPETokenizationStep.from_tiktoken_encoding(encoding, reference_vocab=reference_vocab),
@@ -1100,7 +1103,7 @@
    )

    # (chat)GLM model adds spaces around <sop> token
    decoder_vocab = pipeline[2].vocab
    decoder_vocab = pipeline[3].vocab
    sop_index = next((idx for idx, token in enumerate(decoder_vocab) if token == "<sop>"), None)
    if sop_index is not None:
        decoder_vocab[sop_index] = " <sop> "
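The decoder_vocab index change above follows from SpecialTokensSplit being prepended to the tiktoken pipeline: the BPE step that owns the vocab shifts from position 2 to position 3. Assumed encoder-side order after this commit:

# pipeline[0]  SpecialTokensSplit.from_hf_tokenizer(hf_tokenizer)   # new in this commit
# pipeline[1]  NormalizeUnicode("NFC")
# pipeline[2]  RegexSplitStep(split_pattern, behaviour="contiguous")
# pipeline[3]  BPETokenizationStep.from_tiktoken_encoding(...)      # owns .vocab
decoder_vocab = pipeline[3].vocab  # previously pipeline[2].vocab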
@@ -1113,7 +1116,7 @@
    )

    if utf8_replace_mode is not None:
        pipeline.add_steps(UTF8ValidateStep(mode=utf8_replace_mode)),
        (pipeline.add_steps(UTF8ValidateStep(mode=utf8_replace_mode)),)

    if clean_up_tokenization_spaces is None:
        clean_up_tokenization_spaces = getattr(hf_tokenizer, "clean_up_tokenization_spaces", None)
8 changes: 0 additions & 8 deletions python/openvino_tokenizers/tiktoken_parser.py
@@ -2,14 +2,6 @@

from tiktoken import Encoding

from .utils import bytes_to_unicode


# https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
def token_bytes_to_string(b: bytes) -> str:
    byte_encoder = bytes_to_unicode()
    return "".join(byte_encoder[ord(char)] for char in b.decode("latin-1"))


def bpe(mergeable_ranks: Dict[bytes, int], token: bytes, max_rank: Optional[int] = None) -> List[bytes]:
    parts = [bytes([b]) for b in token]
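The bpe helper that remains in tiktoken_parser.py reconstructs merge sequences from tiktoken's mergeable_ranks; a self-contained sketch of that algorithm, based on the gist referenced above, so treat the details as an approximation:

from typing import Dict, List, Optional


def bpe(mergeable_ranks: Dict[bytes, int], token: bytes, max_rank: Optional[int] = None) -> List[bytes]:
    parts = [bytes([b]) for b in token]
    while True:
        # find the adjacent pair with the lowest merge rank
        min_idx, min_rank = None, None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx, min_rank = i, rank
        if min_rank is None or (max_rank is not None and min_rank >= max_rank):
            break
        # merge that pair and keep going until no merge below max_rank remains
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
    return parts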