From e82ab3d2e6be48c799b0eab097399c92403973a6 Mon Sep 17 00:00:00 2001
From: Ryan Hill <38674843+RyanUnderhill@users.noreply.github.com>
Date: Thu, 2 May 2024 23:02:14 -0700
Subject: [PATCH] Merge tokenizer library invalid UTF-8 fix (#390)

Port over a fix from the onnxruntime-extensions tokenizer library to fix an invalid UTF-8 issue.
---
 src/tokenizer/token_bpe.cc      | 33 ++++++++++++----
 src/tokenizer/token_bpe.h       |  1 +
 src/tokenizer/tokenizer.cc      |  6 +--
 src/tokenizer/utils/unescape.cc | 69 ++++++++++++++++++++++++---------
 4 files changed, 80 insertions(+), 29 deletions(-)

diff --git a/src/tokenizer/token_bpe.cc b/src/tokenizer/token_bpe.cc
index 93c897eea..80ac9d5bf 100644
--- a/src/tokenizer/token_bpe.cc
+++ b/src/tokenizer/token_bpe.cc
@@ -237,15 +237,17 @@ std::vector BPETokenizer::Encode(std::string_view sv_input, int64_
       text = text.strip()
     */
     std::u32string str = RemoveConsecutiveSpaces(input);
-    if (IsUnicodeSpace(str.front())) {
-      str.erase(str.begin());
-    }
-    if (IsUnicodeSpace(str.back())) {
-      str.pop_back();
+    if (!str.empty()) {
+      if (IsUnicodeSpace(str.front())) {
+        str.erase(str.begin());
+      }
+      if (IsUnicodeSpace(str.back())) {
+        str.pop_back();
+      }
+      // remove newlines as CLIP ignores them (treats them as whitespace which is then cleaned)
+      str.erase(std::remove(str.begin(), str.end(), U'\n'), str.end());
+      str.erase(std::remove(str.begin(), str.end(), U'\r'), str.end());
     }
-    // remove newlines as CLIP ignores them (treats them as whitespace which is then cleaned)
-    str.erase(std::remove(str.begin(), str.end(), U'\n'), str.end());
-    str.erase(std::remove(str.begin(), str.end(), U'\r'), str.end());
     input = str;
   }
 
@@ -592,6 +594,21 @@ TfmStatus BPETokenizer::Id2Token(tfmTokenId_t id, std::string& token, DecoderSta
         token.push_back(' ');
       }
     } // end case of whitespace_token_
+
+    bpe_state->incomplete_utf8_ += token;
+    token.clear();
+    std::string& s_utf8 = bpe_state->incomplete_utf8_;
+    size_t utf8_len = 1;
+    size_t utf8_all_len = 0;
+    for (size_t i = 0; i < s_utf8.size(); i += utf8_len) {
+      utf8_len = UTF8Len(s_utf8[i]);
+      if (utf8_len <= s_utf8.size() - i) {
+        utf8_all_len += utf8_len;
+        auto _t = s_utf8.substr(i, utf8_len);
+        token += ValidateUTF8(_t) ? _t : "";
+      }
+    }
+    s_utf8 = s_utf8.substr(utf8_all_len);
   }
 
   return status;
diff --git a/src/tokenizer/token_bpe.h b/src/tokenizer/token_bpe.h
index ed5f1f23c..2327b3a60 100644
--- a/src/tokenizer/token_bpe.h
+++ b/src/tokenizer/token_bpe.h
@@ -28,6 +28,7 @@ class BPETokenizer : public TokenizerImpl {
     BPEDeocerState() = default;
     ~BPEDeocerState() override = default;
     bool f_special_last;
+    std::string incomplete_utf8_;
   };
 
  public:
diff --git a/src/tokenizer/tokenizer.cc b/src/tokenizer/tokenizer.cc
index b2a0622e7..251595856 100644
--- a/src/tokenizer/tokenizer.cc
+++ b/src/tokenizer/tokenizer.cc
@@ -30,10 +30,10 @@ TfmStatus CreateBPETokenizer(const std::string& tokenizer_path,
   if (type.empty()) {
     if (BPETokenizer::IsSupportedModel(GetModelName(token_cfg->tokenizer_class_))) {
       type = "BPE";
-    } else if (std::filesystem::exists(tokenizer_path + "/tokenizer.model")) {
+    } /* else if (std::filesystem::exists(tokenizer_path + "/tokenizer.model")) {
       // if 'tokenizer.model exists in the tokenizer_path, then it is a sentencepiece model
       type = "SPM";
-    } else {
+    } */ else {
       status = TfmStatus(kTfmErrorInvalidArgument, "Cannot determine the tokenizer type from tokenizer_path argument");
     }
   }
@@ -43,7 +43,7 @@ TfmStatus CreateBPETokenizer(const std::string& tokenizer_path,
   } /* else if (type == "SPM") {
     token_ptr = std::make_unique();
   } */ else {
-    status = TfmStatus(kTfmErrorInvalidArgument, "Unknown tokenizer_type, (BPE, SPM, RKWV) are supported.");
+    status = TfmStatus(kTfmErrorInvalidArgument, "Unknown tokenizer_type, (BPE, RKWV) are supported.");
   }
 
   if (status.ok()) {
diff --git a/src/tokenizer/utils/unescape.cc b/src/tokenizer/utils/unescape.cc
index f42e962f9..f94a1f192 100644
--- a/src/tokenizer/utils/unescape.cc
+++ b/src/tokenizer/utils/unescape.cc
@@ -41,27 +41,60 @@ std::string EncodeUTF8Char(char32_t utf8_char) {
   return {utf8_buf};
 }
 
-bool ValidateUTF8(const std::string& data) {
-  int cnt = 0;
-  for (size_t i = 0; i < data.size(); i++) {
-    int x = data[i];
-    if (!cnt) {
-      if ((x >> 5) == 0b110) {
-        cnt = 1;
-      } else if ((x >> 4) == 0b1110) {
-        cnt = 2;
-      } else if ((x >> 3) == 0b11110) {
-        cnt = 3;
-      } else if ((x >> 7) != 0) {
+  bool ValidateUTF8(const std::string& data) {
+    const unsigned char* s = reinterpret_cast<const unsigned char*>(data.c_str());
+    const unsigned char* s_end = s + data.size();
+    if (*s_end != '\0')
+      return false;
+
+    while (*s) {
+      if (*s < 0x80)
+        /* 0xxxxxxx */
+        s++;
+      else if ((s[0] & 0xe0) == 0xc0) {
+        /* 110XXXXx 10xxxxxx */
+        if (s + 1 >= s_end) {
+          return false;
+        }
+        if ((s[1] & 0xc0) != 0x80 ||
+            (s[0] & 0xfe) == 0xc0) /* overlong? */
+          return false;
+        else
+          s += 2;
+      } else if ((s[0] & 0xf0) == 0xe0) {
+        /* 1110XXXX 10Xxxxxx 10xxxxxx */
+        if (s + 2 >= s_end) {
+          return false;
+        }
+        if ((s[1] & 0xc0) != 0x80 ||
+            (s[2] & 0xc0) != 0x80 ||
+            (s[0] == 0xe0 && (s[1] & 0xe0) == 0x80) || /* overlong? */
+            (s[0] == 0xed && (s[1] & 0xe0) == 0xa0) || /* surrogate? */
+            (s[0] == 0xef && s[1] == 0xbf &&
+             (s[2] & 0xfe) == 0xbe)) /* U+FFFE or U+FFFF? */
+          return false;
+        else
+          s += 3;
+      } else if ((s[0] & 0xf8) == 0xf0) {
+        /* 11110XXX 10XXxxxx 10xxxxxx 10xxxxxx */
+        if (s + 3 >= s_end) {
+          return false;
+        }
+        if ((s[1] & 0xc0) != 0x80 ||
+            (s[2] & 0xc0) != 0x80 ||
+            (s[3] & 0xc0) != 0x80 ||
+            (s[0] == 0xf0 && (s[1] & 0xf0) == 0x80) || /* overlong? */
+            (s[0] == 0xf4 && s[1] > 0x8f) || s[0] > 0xf4) /* > U+10FFFF? */
+          return false;
+        else
+          s += 4;
+      } else
         return false;
-      }
-    } else {
-      if ((x >> 6) != 0b10) return false;
-      cnt--;
     }
+
+    return true;
   }
-  return cnt == 0;
-}
+
 
 bool IsDigit(char c) { return c >= '0' && c <= '9'; }
 bool IsHexDigit(char c) { return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'); }