From 1650218315b0599803aa1fd22b28ea419114a04a Mon Sep 17 00:00:00 2001 From: cgli Date: Thu, 5 Dec 2024 18:51:46 +0800 Subject: [PATCH] =?UTF-8?q?Jinja2=E6=A8=A1=E6=9D=BF=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=8F=96=E6=A8=A1=E8=BF=90=E7=AE=97=E5=92=8C=E6=95=B0=E7=BB=84?= =?UTF-8?q?=E5=88=87=E7=89=87=EF=BC=8C=E6=94=AF=E6=8C=81=E7=9B=B4=E6=8E=A5?= =?UTF-8?q?=E5=8A=A0=E8=BD=BDcodellama-7b-instruct=E7=9A=84HF=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/models.md | 4 ++-- include/template.h | 8 +++++--- src/model.cpp | 5 ++++- src/template.cpp | 51 ++++++++++++++++++++++++++++++++++++++-------- 4 files changed, 54 insertions(+), 14 deletions(-) diff --git a/docs/models.md b/docs/models.md index 8860da77..cb60ef78 100644 --- a/docs/models.md +++ b/docs/models.md @@ -99,8 +99,8 @@ |-----------------: |------------|------------|------------| | meta-llama/Llama-2-7b-chat-hf | [✔](llama_cookbook.md#llama2-chat) | [✔](llama_cookbook.md#llama2-chat) | | | meta-llama/Llama-2-13b-chat-hf | [✔](llama_cookbook.md#llama2-chat) | [✔](llama_cookbook.md#llama2-chat) | | -| codellama/CodeLlama-7b-Instruct-hf | [✔](llama_cookbook.md#llama2-chat) | [✔](llama_cookbook.md#llama2-chat) | | -| codellama/CodeLlama-13b-Instruct-hf | [✔](llama_cookbook.md#llama2-chat) | [✔](llama_cookbook.md#llama2-chat) | | +| codellama/CodeLlama-7b-Instruct-hf | [✔](llama_cookbook.md#llama2-chat) | [✔](llama_cookbook.md#llama2-chat) | ✔ | +| codellama/CodeLlama-13b-Instruct-hf | [✔](llama_cookbook.md#llama2-chat) | [✔](llama_cookbook.md#llama2-chat) | ✔ | | xverse/XVERSE-13B-Chat | [✔](llama_cookbook.md#xverse) | [✔](llama_cookbook.md#xverse) | | | xverse/XVERSE-7B-Chat | [✔](llama_cookbook.md#xverse) | [✔](llama_cookbook.md#xverse) | | | | | | | diff --git a/include/template.h b/include/template.h index e6eb6d76..8f03f11e 100644 --- a/include/template.h +++ b/include/template.h @@ -53,9 +53,9 @@ namespace fastllm { JinjaTokenLMB, JinjaTokenRMB, JinjaTokenLSB, JinjaTokenRSB, JinjaTokenSet, JinjaTokenFor, JinjaTokenEndFor, JinjaTokenIf, JinjaTokenElse, JinjaTokenElseIf, JinjaTokenEndif, JinjaTokenIn, - JinjaTokenAssign, JinjaTokenNotEqual, JinjaTokenEqual, JinjaTokenAdd, JinjaTokenSub, JinjaTokenMul, JinjaTokenDiv, + JinjaTokenAssign, JinjaTokenNotEqual, JinjaTokenEqual, JinjaTokenAdd, JinjaTokenSub, JinjaTokenMul, JinjaTokenDiv, JinjaTokenMod, JinjaTokenNot, JinjaTokenAnd, JinjaTokenOr, - JinjaTokenFliter, JinjaTokenNamespace + JinjaTokenFilter, JinjaTokenNamespace, JinjaTokenSlice }; JinjaToKenType type; @@ -74,7 +74,9 @@ namespace fastllm { {'-', JinjaToken::JinjaToKenType::JinjaTokenSub}, {'*', JinjaToken::JinjaToKenType::JinjaTokenMul}, {'/', JinjaToken::JinjaToKenType::JinjaTokenDiv}, - {'|', JinjaToken::JinjaToKenType::JinjaTokenFliter} + {'%', JinjaToken::JinjaToKenType::JinjaTokenMod}, + {'|', JinjaToken::JinjaToKenType::JinjaTokenFilter}, + {':', JinjaToken::JinjaToKenType::JinjaTokenSlice} }; static std::map escapeChars = { diff --git a/src/model.cpp b/src/model.cpp index c118ccee..68314906 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -479,10 +479,13 @@ namespace fastllm { if (tokenizerClass == "PreTrainedTokenizerFast" || tokenizerClass == "LlamaTokenizerFast" || tokenizerClass == "Qwen2Tokenizer" || tokenizerClass == "BloomTokenizer" - || tokenizerClass == "LlamaTokenizer" + || tokenizerClass == "LlamaTokenizer" || tokenizerClass == "CodeLlamaTokenizer" || tokenizerClass == "MiniCPMTokenizer") { // PreTrainedTokenizerFast std::string tokenizerFile = path + "tokenizer.json"; + if (!fastllm::FileExists(tokenizerFile)) { + ErrorInFastLLM("Model with a supported tokenizer_class: " + tokenizerClass + ",but has no \"tokenizer.json\"!"); + } auto tokenizer = json11::Json::parse(ReadAllFile(tokenizerFile), error); for (auto &it : tokenizer["model"]["vocab"].object_items()) { model->weight.AddTokenizerWord(it.first, it.second.int_value(), 1.0f); diff --git a/src/template.cpp b/src/template.cpp index 3ba85db7..ddacaa6f 100644 --- a/src/template.cpp +++ b/src/template.cpp @@ -221,6 +221,12 @@ namespace fastllm { } else if (op == JinjaToken::JinjaTokenAdd) { if (a.type == JinjaVar::JinjaString && b.type == JinjaVar::JinjaString) { return a.stringValue + b.stringValue; + } else if (a.type == JinjaVar::JinjaInt && b.type == JinjaVar::JinjaInt) { + return a.intValue + b.intValue; + } + } else if (op == JinjaToken::JinjaTokenMod) { + if (a.type == JinjaVar::JinjaInt && b.type == JinjaVar::JinjaInt) { + return a.intValue % b.intValue; } } else if (op == JinjaToken::JinjaTokenIn) { return b.dictValue.find(a.stringValue) != b.dictValue.end(); @@ -264,16 +270,18 @@ namespace fastllm { return -2; } else if (type == JinjaToken::JinjaTokenNot) { return -1; - } else if (type == JinjaToken::JinjaTokenEqual || type == JinjaToken::JinjaTokenNotEqual) { + } else if (type == JinjaToken::JinjaTokenEqual || type == JinjaToken::JinjaTokenNotEqual || type == JinjaToken::JinjaTokenIn) { return 0; } else if (type == JinjaToken::JinjaTokenAdd || type == JinjaToken::JinjaTokenSub) { return 1; - } else if (type == JinjaToken::JinjaTokenMul || type == JinjaToken::JinjaTokenDiv) { + } else if (type == JinjaToken::JinjaTokenMul || type == JinjaToken::JinjaTokenDiv || type == JinjaToken::JinjaTokenMod) { return 2; - } else if (type == JinjaToken::JinjaTokenFliter) { + } else if (type == JinjaToken::JinjaTokenFilter) { return 3; } else if (type == JinjaToken::JinjaTokenDOT) { return 4; + } else if (type == JinjaToken::JinjaTokenSlice) { + return 5; } else if (type == JinjaToken::JinjaTokenLSB || type == JinjaToken::JinjaTokenLMB) { return -5; } else { @@ -298,7 +306,15 @@ namespace fastllm { if (trimNext) part.erase(part.find_last_not_of(" \n\r\t") + 1); blocks.push_back(JinjaBlock(part)); - blocks.push_back(temp.substr(i, curEnd + 2 - i)); + // 处理切片语法糖 + part = temp.substr(i, curEnd + 2 - i); + size_t slicepos = part.find("[:"); + if (slicepos != std::string::npos) + part.replace(slicepos, 2, "[0:"); + slicepos = part.find(":]"); + if (slicepos != std::string::npos) + part.replace(slicepos, 2, ":0]"); + blocks.push_back(JinjaBlock(part)); trimNext = (temp[curEnd - 1] == '-'); pos = curEnd + 2; i = curEnd + 1; @@ -342,20 +358,23 @@ namespace fastllm { ops.pop_back(); } AssertInFastLLM(ops.size() > 0 && ops.back().type == JinjaToken::JinjaTokenLMB, "Error: barckets doesn't match."); - suffixExp.push_back(tokens[i]); + if (suffixExp.back().type != JinjaToken::JinjaTokenSlice) + suffixExp.push_back(tokens[i]); ops.pop_back(); } else if (tokens[i].type == JinjaToken::JinjaTokenDOT || tokens[i].type == JinjaToken::JinjaTokenAdd || tokens[i].type == JinjaToken::JinjaTokenSub || tokens[i].type == JinjaToken::JinjaTokenMul || tokens[i].type == JinjaToken::JinjaTokenDiv || + tokens[i].type == JinjaToken::JinjaTokenMod || tokens[i].type == JinjaToken::JinjaTokenEqual || tokens[i].type == JinjaToken::JinjaTokenNotEqual || + tokens[i].type == JinjaToken::JinjaTokenSlice || tokens[i].type == JinjaToken::JinjaTokenIn || tokens[i].type == JinjaToken::JinjaTokenAnd || tokens[i].type == JinjaToken::JinjaTokenOr || tokens[i].type == JinjaToken::JinjaTokenNot || - tokens[i].type == JinjaToken::JinjaTokenFliter) { + tokens[i].type == JinjaToken::JinjaTokenFilter) { while (ops.size() > 0 && GetOpLevel(ops.back().type) > GetOpLevel(tokens[i].type)) { suffixExp.push_back(ops.back()); ops.pop_back(); @@ -412,7 +431,7 @@ namespace fastllm { vars.pop_back(); vars.pop_back(); vars.push_back(JinjaVar({{a.stringValue, b}})); - } else if (it.type == JinjaToken::JinjaTokenFliter) { + } else if (it.type == JinjaToken::JinjaTokenFilter) { AssertInFastLLM(vars.size() > 1, "Jinja Error: expression error."); JinjaVar a = vars[vars.size() - 2], b = vars.back(); if (a.type == JinjaVar::JinjaNone) { @@ -437,6 +456,7 @@ namespace fastllm { it.type == JinjaToken::JinjaTokenSub || it.type == JinjaToken::JinjaTokenMul || it.type == JinjaToken::JinjaTokenDiv || + it.type == JinjaToken::JinjaTokenMod || it.type == JinjaToken::JinjaTokenAssign || it.type == JinjaToken::JinjaTokenEqual || it.type == JinjaToken::JinjaTokenNotEqual || @@ -445,7 +465,7 @@ namespace fastllm { it.type == JinjaToken::JinjaTokenOr) { AssertInFastLLM(vars.size() > 1, "Jinja Error: expression error."); JinjaVar a = vars[vars.size() - 2], b = vars.back(); - if (a.type == JinjaVar::JinjaNone) { + if (a.type == JinjaVar::JinjaNone && it.type != JinjaToken::JinjaTokenIn) { a = local[a]; } if (b.type == JinjaVar::JinjaNone) { @@ -454,6 +474,21 @@ namespace fastllm { vars.pop_back(); vars.pop_back(); vars.push_back(JinjaBinaryOp(a, b, it.type)); + } else if (it.type == JinjaToken::JinjaTokenSlice) { + AssertInFastLLM(vars.size() >= 3, "Jinja Error: expression error."); + JinjaVar a = vars[vars.size() - 3], b = vars[vars.size() - 2], c = vars.back(); + if (a.type == JinjaVar::JinjaNone) { + a = local[a]; + } + AssertInFastLLM(a.type == JinjaVar::JinjaArray && b.type == JinjaVar::JinjaInt && c.type == JinjaVar::JinjaInt, + "Jinja Error: slice expression error."); + vars.pop_back(); + vars.pop_back(); + vars.pop_back(); + if (c.intValue <= 0) + c.intValue += a.arrayValue.size(); + std::vector subArray(a.arrayValue.begin() + b.intValue, a.arrayValue.begin() + c.intValue); + vars.push_back(JinjaVar(subArray)); } }