Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…sions into sayanshaw/phi4-regex
  • Loading branch information
Sayan Shaw committed Jan 16, 2025
2 parents c8002f0 + f8f3ae9 commit c3dc4c1
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .pipelines/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ stages:
steps:
- script: |
cd $(Build.BinariesDirectory)
git clone https://github.com/emscripten-core/emsdk
git clone https://github.com/emscripten-core/emsdk --depth 1 --branch 3.1.74
emsdk/emsdk install latest
emsdk/emsdk activate latest
displayName: Setup emscripten pipeline
Expand Down
12 changes: 9 additions & 3 deletions operators/tokenizer/tokenizer_jsconfig.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ enum class TokenType {
};

constexpr std::pair<const char*, TokenType> kTokenizerDict[] = {
{"PreTrainedTokenizerFast", TokenType::kBPE},
{"PreTrainedTokenizer", TokenType::kBPE},
{"CLIPTokenizer", TokenType::kBPE},
{"WhisperTokenizer", TokenType::kBPE},
{"GemmaTokenizer", TokenType::kBPE},
Expand Down Expand Up @@ -256,10 +256,16 @@ class TokenJsonConfig final {
}

static TokenType GetTokenType(const std::string& tok) {
static const std::unordered_map<std::string, TokenType> dict {
static const std::unordered_map<std::string_view, TokenType> dict {
std::begin(kTokenizerDict), std::end(kTokenizerDict) };

auto iter = dict.find(tok);
std::string_view tok_class(tok);
auto pos = tok_class.find("Fast");
if (pos != std::string_view::npos && pos + 4 == tok_class.size()) {
tok_class.remove_suffix(4);
}

auto iter = dict.find(tok_class);
return iter == dict.end() ? TokenType::kUnknown : iter->second;
}

Expand Down
2 changes: 1 addition & 1 deletion operators/tokenizer/tokenizer_op_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class JsonTokenizerOpKernel {
} else if (type == TokenType::kBPE) {
tokenizer_ = std::make_unique<JsonFastTokenizer>();
} else {
return OrtxStatus(kOrtxErrorCorruptData, "Unknown tokenizer type");
return OrtxStatus(kOrtxErrorCorruptData, "Unknown tokenizer type" + cfg.tokenizer_class_);
}

return std::visit([&](auto& ptr) { return ptr->Load(cfg); }, tokenizer_);
Expand Down
2 changes: 1 addition & 1 deletion shared/api/tokenizer_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ OrtxStatus TokenizerImpl::LoadTokenizer(const OrtxTokenizerBlob* blob) {
return status;
}

return OrtxStatus(kOrtxErrorNotImplemented, "Unsupported tokenizer class");
return OrtxStatus(kOrtxErrorNotImplemented, "Unsupported tokenizer class: " + tok_config_->tokenizer_class_);
}

OrtxStatus TokenizerImpl::Load(const OrtxTokenizerBlob& blob) {
Expand Down
9 changes: 9 additions & 0 deletions test/test_pp_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,15 @@ def test_Qwen_QVQ_tokenizer(self):
ortx_inputs = tokenizer.tokenize(test_sentence)
np.testing.assert_array_equal(ortx_inputs, inputs)

def test_Phi4_tokenizer(self):
model_id = "/g/phi-x-12202024"
test_sentence = [self.tokenizer_test_sentence]
hf_enc = AutoTokenizer.from_pretrained(model_id)
inputs = hf_enc(test_sentence)["input_ids"]
tokenizer = pp_api.Tokenizer(model_id)
ortx_inputs = tokenizer.tokenize(test_sentence)
np.testing.assert_array_equal(ortx_inputs, inputs)


if __name__ == "__main__":
unittest.main()

0 comments on commit c3dc4c1

Please sign in to comment.