Skip to content

Commit

Permalink
Jinja2模板支持取模运算和数组切片,支持直接加载codellama-7b-instruct的HF模型 (Jinja2 templates: support the modulo operator and array slicing; support loading the codellama-7b-instruct HF model directly)
Browse files Browse the repository at this point in the history
  • Loading branch information
cgli committed Dec 5, 2024
1 parent 3b2220d commit 1650218
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 14 deletions.
4 changes: 2 additions & 2 deletions docs/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@
|-----------------: |------------|------------|------------|
| meta-llama/Llama-2-7b-chat-hf | [](llama_cookbook.md#llama2-chat) | [](llama_cookbook.md#llama2-chat) | |
| meta-llama/Llama-2-13b-chat-hf | [](llama_cookbook.md#llama2-chat) | [](llama_cookbook.md#llama2-chat) | |
| codellama/CodeLlama-7b-Instruct-hf | [](llama_cookbook.md#llama2-chat) | [](llama_cookbook.md#llama2-chat) | |
| codellama/CodeLlama-13b-Instruct-hf | [](llama_cookbook.md#llama2-chat) | [](llama_cookbook.md#llama2-chat) | |
| codellama/CodeLlama-7b-Instruct-hf | [](llama_cookbook.md#llama2-chat) | [](llama_cookbook.md#llama2-chat) | |
| codellama/CodeLlama-13b-Instruct-hf | [](llama_cookbook.md#llama2-chat) | [](llama_cookbook.md#llama2-chat) | |
| xverse/XVERSE-13B-Chat | [](llama_cookbook.md#xverse) | [](llama_cookbook.md#xverse) | |
| xverse/XVERSE-7B-Chat | [](llama_cookbook.md#xverse) | [](llama_cookbook.md#xverse) | |
| | | | |
Expand Down
8 changes: 5 additions & 3 deletions include/template.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ namespace fastllm {
JinjaTokenLMB, JinjaTokenRMB, JinjaTokenLSB, JinjaTokenRSB,
JinjaTokenSet, JinjaTokenFor, JinjaTokenEndFor, JinjaTokenIf, JinjaTokenElse, JinjaTokenElseIf, JinjaTokenEndif,
JinjaTokenIn,
JinjaTokenAssign, JinjaTokenNotEqual, JinjaTokenEqual, JinjaTokenAdd, JinjaTokenSub, JinjaTokenMul, JinjaTokenDiv,
JinjaTokenAssign, JinjaTokenNotEqual, JinjaTokenEqual, JinjaTokenAdd, JinjaTokenSub, JinjaTokenMul, JinjaTokenDiv, JinjaTokenMod,
JinjaTokenNot, JinjaTokenAnd, JinjaTokenOr,
JinjaTokenFliter, JinjaTokenNamespace
JinjaTokenFilter, JinjaTokenNamespace, JinjaTokenSlice
};

JinjaToKenType type;
Expand All @@ -74,7 +74,9 @@ namespace fastllm {
{'-', JinjaToken::JinjaToKenType::JinjaTokenSub},
{'*', JinjaToken::JinjaToKenType::JinjaTokenMul},
{'/', JinjaToken::JinjaToKenType::JinjaTokenDiv},
{'|', JinjaToken::JinjaToKenType::JinjaTokenFliter}
{'%', JinjaToken::JinjaToKenType::JinjaTokenMod},
{'|', JinjaToken::JinjaToKenType::JinjaTokenFilter},
{':', JinjaToken::JinjaToKenType::JinjaTokenSlice}
};

static std::map <char, char> escapeChars = {
Expand Down
5 changes: 4 additions & 1 deletion src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -479,10 +479,13 @@ namespace fastllm {
if (tokenizerClass == "PreTrainedTokenizerFast" || tokenizerClass == "LlamaTokenizerFast"
|| tokenizerClass == "Qwen2Tokenizer"
|| tokenizerClass == "BloomTokenizer"
|| tokenizerClass == "LlamaTokenizer"
|| tokenizerClass == "LlamaTokenizer" || tokenizerClass == "CodeLlamaTokenizer"
|| tokenizerClass == "MiniCPMTokenizer") {
// PreTrainedTokenizerFast
std::string tokenizerFile = path + "tokenizer.json";
if (!fastllm::FileExists(tokenizerFile)) {
ErrorInFastLLM("Model with a supported tokenizer_class: " + tokenizerClass + ",but has no \"tokenizer.json\"!");
}
auto tokenizer = json11::Json::parse(ReadAllFile(tokenizerFile), error);
for (auto &it : tokenizer["model"]["vocab"].object_items()) {
model->weight.AddTokenizerWord(it.first, it.second.int_value(), 1.0f);
Expand Down
51 changes: 43 additions & 8 deletions src/template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,12 @@ namespace fastllm {
} else if (op == JinjaToken::JinjaTokenAdd) {
if (a.type == JinjaVar::JinjaString && b.type == JinjaVar::JinjaString) {
return a.stringValue + b.stringValue;
} else if (a.type == JinjaVar::JinjaInt && b.type == JinjaVar::JinjaInt) {
return a.intValue + b.intValue;
}
} else if (op == JinjaToken::JinjaTokenMod) {
if (a.type == JinjaVar::JinjaInt && b.type == JinjaVar::JinjaInt) {
return a.intValue % b.intValue;
}
} else if (op == JinjaToken::JinjaTokenIn) {
return b.dictValue.find(a.stringValue) != b.dictValue.end();
Expand Down Expand Up @@ -264,16 +270,18 @@ namespace fastllm {
return -2;
} else if (type == JinjaToken::JinjaTokenNot) {
return -1;
} else if (type == JinjaToken::JinjaTokenEqual || type == JinjaToken::JinjaTokenNotEqual) {
} else if (type == JinjaToken::JinjaTokenEqual || type == JinjaToken::JinjaTokenNotEqual || type == JinjaToken::JinjaTokenIn) {
return 0;
} else if (type == JinjaToken::JinjaTokenAdd || type == JinjaToken::JinjaTokenSub) {
return 1;
} else if (type == JinjaToken::JinjaTokenMul || type == JinjaToken::JinjaTokenDiv) {
} else if (type == JinjaToken::JinjaTokenMul || type == JinjaToken::JinjaTokenDiv || type == JinjaToken::JinjaTokenMod) {
return 2;
} else if (type == JinjaToken::JinjaTokenFliter) {
} else if (type == JinjaToken::JinjaTokenFilter) {
return 3;
} else if (type == JinjaToken::JinjaTokenDOT) {
return 4;
} else if (type == JinjaToken::JinjaTokenSlice) {
return 5;
} else if (type == JinjaToken::JinjaTokenLSB || type == JinjaToken::JinjaTokenLMB) {
return -5;
} else {
Expand All @@ -298,7 +306,15 @@ namespace fastllm {
if (trimNext)
part.erase(part.find_last_not_of(" \n\r\t") + 1);
blocks.push_back(JinjaBlock(part));
blocks.push_back(temp.substr(i, curEnd + 2 - i));
// 处理切片语法糖
part = temp.substr(i, curEnd + 2 - i);
size_t slicepos = part.find("[:");
if (slicepos != std::string::npos)
part.replace(slicepos, 2, "[0:");
slicepos = part.find(":]");
if (slicepos != std::string::npos)
part.replace(slicepos, 2, ":0]");
blocks.push_back(JinjaBlock(part));
trimNext = (temp[curEnd - 1] == '-');
pos = curEnd + 2;
i = curEnd + 1;
Expand Down Expand Up @@ -342,20 +358,23 @@ namespace fastllm {
ops.pop_back();
}
AssertInFastLLM(ops.size() > 0 && ops.back().type == JinjaToken::JinjaTokenLMB, "Error: barckets doesn't match.");
suffixExp.push_back(tokens[i]);
if (suffixExp.back().type != JinjaToken::JinjaTokenSlice)
suffixExp.push_back(tokens[i]);
ops.pop_back();
} else if (tokens[i].type == JinjaToken::JinjaTokenDOT ||
tokens[i].type == JinjaToken::JinjaTokenAdd ||
tokens[i].type == JinjaToken::JinjaTokenSub ||
tokens[i].type == JinjaToken::JinjaTokenMul ||
tokens[i].type == JinjaToken::JinjaTokenDiv ||
tokens[i].type == JinjaToken::JinjaTokenMod ||
tokens[i].type == JinjaToken::JinjaTokenEqual ||
tokens[i].type == JinjaToken::JinjaTokenNotEqual ||
tokens[i].type == JinjaToken::JinjaTokenSlice ||
tokens[i].type == JinjaToken::JinjaTokenIn ||
tokens[i].type == JinjaToken::JinjaTokenAnd ||
tokens[i].type == JinjaToken::JinjaTokenOr ||
tokens[i].type == JinjaToken::JinjaTokenNot ||
tokens[i].type == JinjaToken::JinjaTokenFliter) {
tokens[i].type == JinjaToken::JinjaTokenFilter) {
while (ops.size() > 0 && GetOpLevel(ops.back().type) > GetOpLevel(tokens[i].type)) {
suffixExp.push_back(ops.back());
ops.pop_back();
Expand Down Expand Up @@ -412,7 +431,7 @@ namespace fastllm {
vars.pop_back();
vars.pop_back();
vars.push_back(JinjaVar({{a.stringValue, b}}));
} else if (it.type == JinjaToken::JinjaTokenFliter) {
} else if (it.type == JinjaToken::JinjaTokenFilter) {
AssertInFastLLM(vars.size() > 1, "Jinja Error: expression error.");
JinjaVar a = vars[vars.size() - 2], b = vars.back();
if (a.type == JinjaVar::JinjaNone) {
Expand All @@ -437,6 +456,7 @@ namespace fastllm {
it.type == JinjaToken::JinjaTokenSub ||
it.type == JinjaToken::JinjaTokenMul ||
it.type == JinjaToken::JinjaTokenDiv ||
it.type == JinjaToken::JinjaTokenMod ||
it.type == JinjaToken::JinjaTokenAssign ||
it.type == JinjaToken::JinjaTokenEqual ||
it.type == JinjaToken::JinjaTokenNotEqual ||
Expand All @@ -445,7 +465,7 @@ namespace fastllm {
it.type == JinjaToken::JinjaTokenOr) {
AssertInFastLLM(vars.size() > 1, "Jinja Error: expression error.");
JinjaVar a = vars[vars.size() - 2], b = vars.back();
if (a.type == JinjaVar::JinjaNone) {
if (a.type == JinjaVar::JinjaNone && it.type != JinjaToken::JinjaTokenIn) {
a = local[a];
}
if (b.type == JinjaVar::JinjaNone) {
Expand All @@ -454,6 +474,21 @@ namespace fastllm {
vars.pop_back();
vars.pop_back();
vars.push_back(JinjaBinaryOp(a, b, it.type));
} else if (it.type == JinjaToken::JinjaTokenSlice) {
AssertInFastLLM(vars.size() >= 3, "Jinja Error: expression error.");
JinjaVar a = vars[vars.size() - 3], b = vars[vars.size() - 2], c = vars.back();
if (a.type == JinjaVar::JinjaNone) {
a = local[a];
}
AssertInFastLLM(a.type == JinjaVar::JinjaArray && b.type == JinjaVar::JinjaInt && c.type == JinjaVar::JinjaInt,
"Jinja Error: slice expression error.");
vars.pop_back();
vars.pop_back();
vars.pop_back();
if (c.intValue <= 0)
c.intValue += a.arrayValue.size();
std::vector<JinjaVar> subArray(a.arrayValue.begin() + b.intValue, a.arrayValue.begin() + c.intValue);
vars.push_back(JinjaVar(subArray));
}
}

Expand Down

0 comments on commit 1650218

Please sign in to comment.