diff --git a/docs/models.md b/docs/models.md
index bafa09a9..cb60ef78 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -50,10 +50,10 @@
| 模型 | 加载后转换 | 离线转换 | 直接读取 |
|-------------------: |------------|------------|------------|
-| Qwen/Qwen-7B-Chat | [✔](#其它模型) | [✔](#qwen模型导出) | |
-| Qwen/Qwen-14B-Chat | [✔](#其它模型) | [✔](#qwen模型导出) | |
-| Qwen/Qwen-72B-Chat | [✔](#其它模型) | [✔](#qwen模型导出) | |
-| Qwen/Qwen-1_8B-Chat | [✔](#其它模型) | [✔](#qwen模型导出) | |
+| Qwen/Qwen-7B-Chat | [✔](#其它模型) | [✔](#qwen模型导出) | ✔ |
+| Qwen/Qwen-14B-Chat | [✔](#其它模型) | [✔](#qwen模型导出) | ✔ |
+| Qwen/Qwen-72B-Chat | [✔](#其它模型) | [✔](#qwen模型导出) | ✔ |
+| Qwen/Qwen-1_8B-Chat | [✔](#其它模型) | [✔](#qwen模型导出) | ✔ |
| Qwen/Qwen1.5-0.5B-Chat | [✔](#其它模型) | [✔](#qwen模型导出) | ✔3 |
| Qwen/Qwen1.5-1.8B-Chat | [✔](#其它模型) | [✔](#qwen模型导出) | ✔3 |
| Qwen/Qwen1.5-4B-Chat | [✔](#其它模型) | [✔](#qwen模型导出) | ✔3 |
@@ -99,8 +99,8 @@
|-----------------: |------------|------------|------------|
| meta-llama/Llama-2-7b-chat-hf | [✔](llama_cookbook.md#llama2-chat) | [✔](llama_cookbook.md#llama2-chat) | |
| meta-llama/Llama-2-13b-chat-hf | [✔](llama_cookbook.md#llama2-chat) | [✔](llama_cookbook.md#llama2-chat) | |
-| codellama/CodeLlama-7b-Instruct-hf | [✔](llama_cookbook.md#llama2-chat) | [✔](llama_cookbook.md#llama2-chat) | |
-| codellama/CodeLlama-13b-Instruct-hf | [✔](llama_cookbook.md#llama2-chat) | [✔](llama_cookbook.md#llama2-chat) | |
+| codellama/CodeLlama-7b-Instruct-hf | [✔](llama_cookbook.md#llama2-chat) | [✔](llama_cookbook.md#llama2-chat) | ✔ |
+| codellama/CodeLlama-13b-Instruct-hf | [✔](llama_cookbook.md#llama2-chat) | [✔](llama_cookbook.md#llama2-chat) | ✔ |
| xverse/XVERSE-13B-Chat | [✔](llama_cookbook.md#xverse) | [✔](llama_cookbook.md#xverse) | |
| xverse/XVERSE-7B-Chat | [✔](llama_cookbook.md#xverse) | [✔](llama_cookbook.md#xverse) | |
| | | | |
@@ -130,6 +130,9 @@
| | | | |
| openbmb/MiniCPM-2B-sft-fp16 | [✔](#其它模型) | [✔](#minicpm模型导出) | |
| openbmb/MiniCPM-2B-dpo-fp16 | [✔](#其它模型) | [✔](#minicpm模型导出) | |
+| openbmb/MiniCPM3-4B | [✔](#其它模型) | [✔](#minicpm模型导出) | |
+| | | | |
+| microsoft/Phi-3-mini-4k-instruct | | | ✔ |
### 加载后转换(两行加速模式)(convert on-the-fly)
@@ -265,6 +268,7 @@ python3 tools/llamalike2flm.py qwen1.5-7b-int4.flm int4 "qwen/Qwen1.5-14B-Chat"
# 需要先安装MiniCPM环境(transformers >= 4.36.0)
# 默认脚本导出MiniCPM-2B-dpo-fp16模型
cd build
-python tools/minicpm2flm.py minicpm-2b-float16.flm #导出dpo-float16模型
+python tools/minicpm2flm.py minicpm-2b-fp16.flm #导出dpo-float16模型
+python tools/minicpm2flm.py minicpm3-4b-fp16.flm openbmb/MiniCPM3-4B #导出minicpm3-float16模型
-./main -p minicpm-2b-float16.flm # 执行模型
+./main -p minicpm-2b-fp16.flm # 执行模型
```
\ No newline at end of file
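The new "直接读取" (direct read) entries above mean these checkpoints can be loaded straight from the original Hugging Face folder, with no offline .flm conversion step. A minimal sketch using the ftllm Python bindings built from this repo (the local path is an assumption, not part of this patch):

```python
from ftllm import llm

# Point at the downloaded Hugging Face checkpoint folder; fastllm reads the
# config/tokenizer/weights in place instead of requiring a converted .flm file.
model = llm.model("/path/to/Qwen-7B-Chat")
print(model.response("你好"))
```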
diff --git a/example/Android/LLMAssistant/app/src/main/cpp/CMakeLists.txt b/example/Android/LLMAssistant/app/src/main/cpp/CMakeLists.txt
index 9879ddc7..0b852915 100644
--- a/example/Android/LLMAssistant/app/src/main/cpp/CMakeLists.txt
+++ b/example/Android/LLMAssistant/app/src/main/cpp/CMakeLists.txt
@@ -33,17 +33,23 @@ set(PROJECT_SOURCE
../../../../../../../src/template.cpp
../../../../../../../src/devices/cpu/cpudevice.cpp
../../../../../../../src/devices/cpu/cpudevicebatch.cpp
+ ../../../../../../../src/devices/cpu/linear.cpp
../../../../../../../src/models/basellm.cpp
../../../../../../../src/models/bert.cpp
../../../../../../../src/models/chatglm.cpp
+ ../../../../../../../src/models/cogvlm.cpp
../../../../../../../src/models/deepseekv2.cpp
- ../../../../../../../src/models/moss.cpp
- ../../../../../../../src/models/llama.cpp
- ../../../../../../../src/models/qwen.cpp
../../../../../../../src/models/glm.cpp
+ ../../../../../../../src/models/graphllm.cpp
+ ../../../../../../../src/models/llama.cpp
../../../../../../../src/models/internlm2.cpp
../../../../../../../src/models/minicpm.cpp
+ ../../../../../../../src/models/minicpm3.cpp
../../../../../../../src/models/moe.cpp
+ ../../../../../../../src/models/moss.cpp
+ ../../../../../../../src/models/phi3.cpp
+ ../../../../../../../src/models/qwen.cpp
+ ../../../../../../../src/models/xlmroberta.cpp
)
include_directories(
diff --git a/example/Win32Demo/fastllm-gpu.vcxproj b/example/Win32Demo/fastllm-gpu.vcxproj
index 8afd2f46..f0145dc9 100644
--- a/example/Win32Demo/fastllm-gpu.vcxproj
+++ b/example/Win32Demo/fastllm-gpu.vcxproj
@@ -196,9 +196,11 @@
+
+
@@ -206,6 +208,7 @@
+
@@ -213,9 +216,12 @@
+
+
+
@@ -225,8 +231,10 @@
+
+
@@ -234,18 +242,23 @@
+
+
+
+
+
@@ -253,9 +266,15 @@
Document
+
+ Document
+
Document
+
+ Document
+
\ No newline at end of file
diff --git a/example/Win32Demo/fastllm-gpu.vcxproj.filters b/example/Win32Demo/fastllm-gpu.vcxproj.filters
index ebea6d5b..d3768f17 100644
--- a/example/Win32Demo/fastllm-gpu.vcxproj.filters
+++ b/example/Win32Demo/fastllm-gpu.vcxproj.filters
@@ -49,6 +49,12 @@
{385ba64f-d978-469c-af0c-498ec33a1bb7}
+
+ {917a368b-3a0e-476b-9f35-6af4433aca16}
+
+
+ {9fbf8989-df11-4dc7-9697-d69d10b7f0bf}
+
@@ -78,6 +84,9 @@
头文件\models
+
+ 头文件\models
+
头文件\models
@@ -99,21 +108,33 @@
头文件\models
+
+ 头文件\models
+
头文件\models
头文件\models
+
+ 头文件\models
+
头文件\models
-
- 头文件\devices\cpu
+
+ 头文件\models
头文件\devices\cpu
+
+ 头文件\devices\cpu
+
+
+ 头文件\devices\cpu
+
头文件\devices\cpu
@@ -129,6 +150,12 @@
头文件\devices\cuda
+
+ 头文件\devices\multicuda
+
+
+ 头文件\devices\multicuda
+
头文件\third_party
@@ -161,6 +188,9 @@
源文件\models
+
+ 源文件\models
+
源文件\models
@@ -179,18 +209,30 @@
源文件\models
+
+ 源文件\models
+
源文件\models
源文件\models
+
+ 源文件\models
+
源文件\models
+
+ 源文件\models
+
源文件\models\graph
+
+ 源文件\models\graph
+
源文件\models\graph
@@ -203,12 +245,18 @@
源文件\devices\cpu
+
+ 源文件\devices\cpu
+
源文件\devices\cuda
源文件\devices\cuda
+
+ 源文件\devices\multicuda
+
源文件\third_party
@@ -217,5 +265,8 @@
源文件\devices\cuda
+
+ 源文件\devices\multicuda
+
\ No newline at end of file
diff --git a/example/Win32Demo/fastllm.vcxproj b/example/Win32Demo/fastllm.vcxproj
index 421b94fd..5ee39467 100644
--- a/example/Win32Demo/fastllm.vcxproj
+++ b/example/Win32Demo/fastllm.vcxproj
@@ -173,6 +173,7 @@
+
@@ -182,6 +183,7 @@
+
@@ -189,9 +191,12 @@
+
+
+
@@ -201,6 +206,7 @@
+
@@ -208,18 +214,23 @@
+
+
+
+
+
diff --git a/example/Win32Demo/fastllm.vcxproj.filters b/example/Win32Demo/fastllm.vcxproj.filters
index 54497dd0..17387b28 100644
--- a/example/Win32Demo/fastllm.vcxproj.filters
+++ b/example/Win32Demo/fastllm.vcxproj.filters
@@ -78,6 +78,9 @@
头文件\models
+
+ 头文件\models
+
头文件\models
@@ -99,15 +102,24 @@
头文件\models
+
+ 头文件\models
+
头文件\models
头文件\models
+
+ 头文件\models
+
头文件\models
+
+ 头文件\models
+
头文件\devices\cpu
@@ -155,6 +167,9 @@
源文件\models
+
+ 源文件\models
+
源文件\models
@@ -173,18 +188,30 @@
源文件\models
+
+ 源文件\models
+
源文件\models
源文件\models
+
+ 源文件\models
+
源文件\models
+
+ 源文件\models
+
源文件\models\graph
+
+ 源文件\models\graph
+
源文件\models\graph
diff --git a/include/template.h b/include/template.h
index e6eb6d76..8f03f11e 100644
--- a/include/template.h
+++ b/include/template.h
@@ -53,9 +53,9 @@ namespace fastllm {
JinjaTokenLMB, JinjaTokenRMB, JinjaTokenLSB, JinjaTokenRSB,
JinjaTokenSet, JinjaTokenFor, JinjaTokenEndFor, JinjaTokenIf, JinjaTokenElse, JinjaTokenElseIf, JinjaTokenEndif,
JinjaTokenIn,
- JinjaTokenAssign, JinjaTokenNotEqual, JinjaTokenEqual, JinjaTokenAdd, JinjaTokenSub, JinjaTokenMul, JinjaTokenDiv,
+ JinjaTokenAssign, JinjaTokenNotEqual, JinjaTokenEqual, JinjaTokenAdd, JinjaTokenSub, JinjaTokenMul, JinjaTokenDiv, JinjaTokenMod,
JinjaTokenNot, JinjaTokenAnd, JinjaTokenOr,
- JinjaTokenFliter, JinjaTokenNamespace
+ JinjaTokenFilter, JinjaTokenNamespace, JinjaTokenSlice
};
JinjaToKenType type;
@@ -74,7 +74,9 @@ namespace fastllm {
{'-', JinjaToken::JinjaToKenType::JinjaTokenSub},
{'*', JinjaToken::JinjaToKenType::JinjaTokenMul},
{'/', JinjaToken::JinjaToKenType::JinjaTokenDiv},
- {'|', JinjaToken::JinjaToKenType::JinjaTokenFliter}
+ {'%', JinjaToken::JinjaToKenType::JinjaTokenMod},
+ {'|', JinjaToken::JinjaToKenType::JinjaTokenFilter},
+ {':', JinjaToken::JinjaToKenType::JinjaTokenSlice}
};
static std::map <char, char> escapeChars = {
diff --git a/src/model.cpp b/src/model.cpp
index 096f1330..29a03446 100644
--- a/src/model.cpp
+++ b/src/model.cpp
@@ -477,10 +477,13 @@ namespace fastllm {
if (tokenizerClass == "PreTrainedTokenizerFast" || tokenizerClass == "LlamaTokenizerFast"
|| tokenizerClass == "Qwen2Tokenizer"
|| tokenizerClass == "BloomTokenizer"
- || tokenizerClass == "LlamaTokenizer"
+ || tokenizerClass == "LlamaTokenizer" || tokenizerClass == "CodeLlamaTokenizer"
|| tokenizerClass == "MiniCPMTokenizer") {
// PreTrainedTokenizerFast
std::string tokenizerFile = path + "tokenizer.json";
+ if (!fastllm::FileExists(tokenizerFile)) {
+ ErrorInFastLLM("Model with a supported tokenizer_class: " + tokenizerClass + ",but has no \"tokenizer.json\"!");
+ }
auto tokenizer = json11::Json::parse(ReadAllFile(tokenizerFile), error);
for (auto &it : tokenizer["model"]["vocab"].object_items()) {
model->weight.AddTokenizerWord(it.first, it.second.int_value(), 1.0f);
@@ -522,6 +525,18 @@ namespace fastllm {
model->history_sep = "";
model->weight.tokenizer.type = Tokenizer::TokenizerType::QWEN;
model->weight.tokenizer.chatTemplate = "";
+ } else if (tokenizerClass == "QWenTokenizer") {
+ // Qwen用的分词
+ std::vector <std::string> lines, line;
+ SplitString(ReadAllFile(path + "qwen.tiktoken"), {'\n'}, lines);
+ for (int i = 0; i < lines.size(); i++) {
+ SplitString(lines[i], {' '}, line);
+ model->weight.AddTokenizerWord(Base64Decode(line[0]), atoi(line[1].c_str()), 1.0f);
+ }
+ model->weight.tokenizer.type = Tokenizer::TokenizerType::QWEN;
+ model->weight.tokenizer.chatTemplate = "";
+ model->weight.dicts["im_end_id"] = std::to_string(lines.size() + 1);
+ model->weight.dicts["im_start_id"] = std::to_string(lines.size() + 2);
} else {
ErrorInFastLLM("Unsupport tokenizer_class: " + tokenizerClass);
}
@@ -635,7 +650,7 @@ namespace fastllm {
for (auto &it : generation_config.object_items()) {
if ("eos_token_id" == it.first && it.second.type() == json11::Json::ARRAY)
continue;
- model->weight.AddDict(it.first, it.second.dump().c_str());
+ model->weight.AddDict(it.first, it.second.is_string() ? it.second.string_value() : it.second.dump());
}
// 更新eos_token_id
if (generation_config["eos_token_id"].is_array()) {
diff --git a/src/models/minicpm3.cpp b/src/models/minicpm3.cpp
index ef92c433..8396710f 100644
--- a/src/models/minicpm3.cpp
+++ b/src/models/minicpm3.cpp
@@ -80,6 +80,8 @@ namespace fastllm {
if (this->weight.dicts.find("kv_lora_rank") != this->weight.dicts.end()) {
this->kv_lora_rank = std::stoi(this->weight.dicts["kv_lora_rank"]);
}
+ weight.tokenizer.SetSpecialTokens({{"", 2}, {"", 1}, {"", 0}, {"<|im_start|>", 73441}, {"<|im_end|>", 73440}, {"<|tool_call|>", 73442},
+ {"<|execute_start|>", 73443}, {"<|execute_end|>", 73444}, {"<|fim_prefix|>", 73445}, {"<|fim_middle|>", 73446}, {"<|fim_suffix|>", 73447}});
}
int MiniCpm3Model::Forward(const fastllm::Data &inputIds, const fastllm::Data &attentionMask,
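The special-token table registered above lets the fastllm tokenizer map markers such as <|im_start|> to a single id instead of splitting them into sub-word pieces. A quick cross-check against the Hugging Face tokenizer (model id assumed, expected ids taken from the table above):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("openbmb/MiniCPM3-4B", trust_remote_code=True)
print(tok.convert_tokens_to_ids("<|im_start|>"))  # expected 73441, matching the table
print(tok.convert_tokens_to_ids("<|im_end|>"))    # expected 73440
```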
diff --git a/src/models/qwen.cpp b/src/models/qwen.cpp
index 1c1ddfaa..662b9b54 100644
--- a/src/models/qwen.cpp
+++ b/src/models/qwen.cpp
@@ -56,6 +56,11 @@ namespace fastllm {
}
weight.embeddingNames.insert("transformer.wte.weight");
+ weight.linearNames = {
+ "lm_head.weight", "transformer.h.*.ln_1.weight", "transformer.h.*.attn.c_attn.weight",
+ "transformer.h.*.attn.c_proj.weight", "transformer.h.*.ln_2.weight",
+ "transformer.h.*.mlp.w1.weight", "transformer.h.*.mlp.w2.weight", "transformer.h.*.mlp.c_proj.weight"
+ };
}
int QWenModel::Forward(const Data &inputIds,
@@ -166,12 +171,13 @@ namespace fastllm {
// Attention
MatMulTransB(query, pastKey, attnWeights, 1.0 / sqrt(head_dim));
- attnWeights.Reshape({1, attnWeights.dims[0], attnWeights.dims[1], attnWeights.dims[2]});
+ attnWeights.Reshape({batch, -1, attnWeights.dims[1], attnWeights.dims[2]});
if (!attentionMask.dims.empty()) {
AttentionMask(attnWeights, attentionMask, -10000);
}
Softmax(attnWeights, attnWeights, -1);
+ attnWeights.Reshape({1, -1, attnWeights.dims[2], attnWeights.dims[3]});
MatMul(attnWeights, pastValue, attnOutput);
attnOutput.Reshape({attnOutput.dims[1], attnOutput.dims[2], attnOutput.dims[3]});
@@ -460,45 +466,68 @@ namespace fastllm {
Data &inputIds, Data &attentionMask, Data &positionIds) {
int batch = inputTokens.size();
int index = params[0].find("index")->second;
- int promptLen = params[0].find("promptLen")->second;
inputIds.ToDevice(DataDevice::CPU);
attentionMask.ToDevice(DataDevice::CPU);
positionIds.ToDevice(DataDevice::CPU);
+ std::vector <int> seqLens;
+ seqLens.resize(batch);
+ int maxLen = 0;
+ for (int i = 0; i < batch; i++) {
+ int promptLen = params[i].find("promptLen")->second + index;
+ maxLen = std::max(promptLen, maxLen);
+ seqLens[i] = promptLen;
+ }
+
if (index == 0) {
int seqLen = inputTokens[0].size();
- std::vector <float> ids = std::vector <float> (batch * seqLen, 0);
- std::vector <float> vmask = std::vector <float> (batch * seqLen * seqLen, 0);
- std::vector <float> vpids = std::vector <float> (batch * seqLen, 0);
- for (int b = 0; b < batch; b++) {
- for (int i = 0; i < seqLen; i++) {
- ids[b * seqLen + i] = inputTokens[b][i];
+ std::vector <float> ids = std::vector <float> (batch * maxLen, 0);
+ std::vector <float> vpids = std::vector <float> (batch * maxLen, 0);
+ std::vector <float> vmask = std::vector <float> (batch * maxLen * maxLen, 0);
+ for (int i = 0; i < batch; i++) {
+ auto &tokens = inputTokens[i];
+ int len = tokens.size(), base = maxLen - len;
+ for (int j = 0; j < len; j++) {
+ ids[i * maxLen + base + j] = tokens[j];
}
- }
- for (int i = 0; i < seqLen; i++) {
- vpids[i] = i;
- for (int j = i + 1; j < seqLen; j++) {
- vmask[i * seqLen + j] = 1;
+ for (int j = 0; j < len; j++) {
+ vpids[i * maxLen + base + j] = j;
+ }
+
+ std::fill(vmask.data() + i * maxLen * maxLen,
+ vmask.data() + i * maxLen * maxLen + (maxLen - len) * maxLen, 1.0);
+ for (int j = maxLen - len; j < maxLen; j++) {
+ std::fill(vmask.data() + i * maxLen * maxLen + j * maxLen,
+ vmask.data() + i * maxLen * maxLen + j * maxLen + maxLen - len, 1.0);
+ }
+ for (int j = 0; j < len; j++) {
+ for (int k = j + 1; k < len; k++) {
+ vmask[i * maxLen * maxLen + (base + j) * maxLen + base + k] = 1;
+ }
}
}
- for (int b = 1; b < batch; b++) {
- memcpy(vmask.data() + b * seqLen * seqLen, vmask.data(), seqLen * seqLen * sizeof(float));
- memcpy(vpids.data() + b * seqLen, vpids.data(), seqLen * sizeof(float));
- }
- inputIds.CopyFrom(Data(DataType::FLOAT32, {batch, seqLen}, ids));
- attentionMask.CopyFrom(Data(DataType::FLOAT32, {batch, seqLen, seqLen}, vmask));
- positionIds.CopyFrom(Data(DataType::FLOAT32, {batch, seqLen}, vpids));
+
+ inputIds.CopyFrom(Data(DataType::FLOAT32, {batch, maxLen}, ids));
+ attentionMask.CopyFrom(Data(DataType::FLOAT32, {batch, maxLen, maxLen}, vmask));
+ positionIds.CopyFrom(Data(DataType::FLOAT32, {batch, maxLen}, vpids));
} else {
- std::vector <float> ids = std::vector <float> (batch * 1, 0);
- std::vector <float> vpids = std::vector <float> (batch * 1, 0);
- for (int b = 0; b < batch; b++) {
- ids[b] = inputTokens[b][0];
- vpids[b] = (float) (promptLen + index - 1);
+ maxLen++;
+ std::vector <float> fret;
+ for (int i = 0; i < batch; i++) {
+ fret.push_back(inputTokens[i][0]);
}
- inputIds.CopyFrom(Data(DataType::FLOAT32, {batch, 1}, ids));
- attentionMask.CopyFrom(Data());
- positionIds.CopyFrom(Data(DataType::FLOAT32, {batch, 1}, vpids));
+ std::vector <float> pids = std::vector <float> (batch);
+ std::vector <float> vmasks = std::vector <float> (batch * maxLen, 0.0f);
+ for (int i = 0; i < batch; i++) {
+ pids[i] = seqLens[i] - 1;
+ for (int j = 0; j < maxLen - seqLens[i] - 1; j++) {
+ vmasks[i * maxLen + j] = 1.0f;
+ }
+ }
+ inputIds.CopyFrom(Data(DataType::FLOAT32, {batch, 1}, fret));
+ attentionMask.CopyFrom(Data(DataType::FLOAT32, {batch, 1, maxLen}, vmasks));
+ positionIds.CopyFrom(Data(DataType::FLOAT32, {batch, 1}, pids));
}
}
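The reworked FillLLMInputsBatch left-pads a batch of prompts to a common length during prefill: token ids and position ids are right-aligned, and the mask marks blocked positions with 1 (padding rows, padding columns, and the upper-triangular causal part). A small NumPy sketch of the same layout for the index == 0 branch (token ids are made up):

```python
import numpy as np

# Made-up token ids for a batch of two prompts of different lengths.
prompts = [[11, 12, 13], [21, 22]]
max_len = max(len(p) for p in prompts)

ids  = np.zeros((len(prompts), max_len), dtype=np.float32)
pids = np.zeros((len(prompts), max_len), dtype=np.float32)
mask = np.zeros((len(prompts), max_len, max_len), dtype=np.float32)  # 1 = blocked

for i, p in enumerate(prompts):
    base = max_len - len(p)                 # left-padding amount
    ids[i, base:]  = p                      # right-align the real tokens
    pids[i, base:] = np.arange(len(p))      # positions restart at 0 per prompt
    mask[i, :base, :] = 1                   # padded query rows attend to nothing
    mask[i, base:, :base] = 1               # real tokens never attend to padding
    for j in range(len(p)):                 # causal mask inside the real region
        mask[i, base + j, base + j + 1:] = 1

print(ids, pids, mask, sep="\n")
```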
diff --git a/src/template.cpp b/src/template.cpp
index 3ba85db7..ddacaa6f 100644
--- a/src/template.cpp
+++ b/src/template.cpp
@@ -221,6 +221,12 @@ namespace fastllm {
} else if (op == JinjaToken::JinjaTokenAdd) {
if (a.type == JinjaVar::JinjaString && b.type == JinjaVar::JinjaString) {
return a.stringValue + b.stringValue;
+ } else if (a.type == JinjaVar::JinjaInt && b.type == JinjaVar::JinjaInt) {
+ return a.intValue + b.intValue;
+ }
+ } else if (op == JinjaToken::JinjaTokenMod) {
+ if (a.type == JinjaVar::JinjaInt && b.type == JinjaVar::JinjaInt) {
+ return a.intValue % b.intValue;
}
} else if (op == JinjaToken::JinjaTokenIn) {
return b.dictValue.find(a.stringValue) != b.dictValue.end();
@@ -264,16 +270,18 @@ namespace fastllm {
return -2;
} else if (type == JinjaToken::JinjaTokenNot) {
return -1;
- } else if (type == JinjaToken::JinjaTokenEqual || type == JinjaToken::JinjaTokenNotEqual) {
+ } else if (type == JinjaToken::JinjaTokenEqual || type == JinjaToken::JinjaTokenNotEqual || type == JinjaToken::JinjaTokenIn) {
return 0;
} else if (type == JinjaToken::JinjaTokenAdd || type == JinjaToken::JinjaTokenSub) {
return 1;
- } else if (type == JinjaToken::JinjaTokenMul || type == JinjaToken::JinjaTokenDiv) {
+ } else if (type == JinjaToken::JinjaTokenMul || type == JinjaToken::JinjaTokenDiv || type == JinjaToken::JinjaTokenMod) {
return 2;
- } else if (type == JinjaToken::JinjaTokenFliter) {
+ } else if (type == JinjaToken::JinjaTokenFilter) {
return 3;
} else if (type == JinjaToken::JinjaTokenDOT) {
return 4;
+ } else if (type == JinjaToken::JinjaTokenSlice) {
+ return 5;
} else if (type == JinjaToken::JinjaTokenLSB || type == JinjaToken::JinjaTokenLMB) {
return -5;
} else {
@@ -298,7 +306,15 @@ namespace fastllm {
if (trimNext)
part.erase(part.find_last_not_of(" \n\r\t") + 1);
blocks.push_back(JinjaBlock(part));
- blocks.push_back(temp.substr(i, curEnd + 2 - i));
+ // 处理切片语法糖
+ part = temp.substr(i, curEnd + 2 - i);
+ size_t slicepos = part.find("[:");
+ if (slicepos != std::string::npos)
+ part.replace(slicepos, 2, "[0:");
+ slicepos = part.find(":]");
+ if (slicepos != std::string::npos)
+ part.replace(slicepos, 2, ":0]");
+ blocks.push_back(JinjaBlock(part));
trimNext = (temp[curEnd - 1] == '-');
pos = curEnd + 2;
i = curEnd + 1;
@@ -342,20 +358,23 @@ namespace fastllm {
ops.pop_back();
}
AssertInFastLLM(ops.size() > 0 && ops.back().type == JinjaToken::JinjaTokenLMB, "Error: brackets don't match.");
- suffixExp.push_back(tokens[i]);
+ if (suffixExp.back().type != JinjaToken::JinjaTokenSlice)
+ suffixExp.push_back(tokens[i]);
ops.pop_back();
} else if (tokens[i].type == JinjaToken::JinjaTokenDOT ||
tokens[i].type == JinjaToken::JinjaTokenAdd ||
tokens[i].type == JinjaToken::JinjaTokenSub ||
tokens[i].type == JinjaToken::JinjaTokenMul ||
tokens[i].type == JinjaToken::JinjaTokenDiv ||
+ tokens[i].type == JinjaToken::JinjaTokenMod ||
tokens[i].type == JinjaToken::JinjaTokenEqual ||
tokens[i].type == JinjaToken::JinjaTokenNotEqual ||
+ tokens[i].type == JinjaToken::JinjaTokenSlice ||
tokens[i].type == JinjaToken::JinjaTokenIn ||
tokens[i].type == JinjaToken::JinjaTokenAnd ||
tokens[i].type == JinjaToken::JinjaTokenOr ||
tokens[i].type == JinjaToken::JinjaTokenNot ||
- tokens[i].type == JinjaToken::JinjaTokenFliter) {
+ tokens[i].type == JinjaToken::JinjaTokenFilter) {
while (ops.size() > 0 && GetOpLevel(ops.back().type) > GetOpLevel(tokens[i].type)) {
suffixExp.push_back(ops.back());
ops.pop_back();
@@ -412,7 +431,7 @@ namespace fastllm {
vars.pop_back();
vars.pop_back();
vars.push_back(JinjaVar({{a.stringValue, b}}));
- } else if (it.type == JinjaToken::JinjaTokenFliter) {
+ } else if (it.type == JinjaToken::JinjaTokenFilter) {
AssertInFastLLM(vars.size() > 1, "Jinja Error: expression error.");
JinjaVar a = vars[vars.size() - 2], b = vars.back();
if (a.type == JinjaVar::JinjaNone) {
@@ -437,6 +456,7 @@ namespace fastllm {
it.type == JinjaToken::JinjaTokenSub ||
it.type == JinjaToken::JinjaTokenMul ||
it.type == JinjaToken::JinjaTokenDiv ||
+ it.type == JinjaToken::JinjaTokenMod ||
it.type == JinjaToken::JinjaTokenAssign ||
it.type == JinjaToken::JinjaTokenEqual ||
it.type == JinjaToken::JinjaTokenNotEqual ||
@@ -445,7 +465,7 @@ namespace fastllm {
it.type == JinjaToken::JinjaTokenOr) {
AssertInFastLLM(vars.size() > 1, "Jinja Error: expression error.");
JinjaVar a = vars[vars.size() - 2], b = vars.back();
- if (a.type == JinjaVar::JinjaNone) {
+ if (a.type == JinjaVar::JinjaNone && it.type != JinjaToken::JinjaTokenIn) {
a = local[a];
}
if (b.type == JinjaVar::JinjaNone) {
@@ -454,6 +474,21 @@ namespace fastllm {
vars.pop_back();
vars.pop_back();
vars.push_back(JinjaBinaryOp(a, b, it.type));
+ } else if (it.type == JinjaToken::JinjaTokenSlice) {
+ AssertInFastLLM(vars.size() >= 3, "Jinja Error: expression error.");
+ JinjaVar a = vars[vars.size() - 3], b = vars[vars.size() - 2], c = vars.back();
+ if (a.type == JinjaVar::JinjaNone) {
+ a = local[a];
+ }
+ AssertInFastLLM(a.type == JinjaVar::JinjaArray && b.type == JinjaVar::JinjaInt && c.type == JinjaVar::JinjaInt,
+ "Jinja Error: slice expression error.");
+ vars.pop_back();
+ vars.pop_back();
+ vars.pop_back();
+ if (c.intValue <= 0)
+ c.intValue += a.arrayValue.size();
+ std::vector <JinjaVar> subArray(a.arrayValue.begin() + b.intValue, a.arrayValue.begin() + c.intValue);
+ vars.push_back(JinjaVar(subArray));
}
}
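Together, the new Mod and Slice tokens (plus the `[:` / `:]` rewrite above) let the template engine evaluate chat_template constructs such as list slicing and index parity. For reference semantics, the same constructs rendered with the jinja2 package (template and messages are illustrative, not taken from any model):

```python
from jinja2 import Template

# Illustrative chat_template fragment using slicing and the % operator.
tmpl = Template(
    "{% for m in messages[1:] %}"                               # slice: skip the system turn
    "{% if loop.index0 % 2 == 0 %}U: {% else %}A: {% endif %}"  # parity via %
    "{{ m }}\n"
    "{% endfor %}"
)
print(tmpl.render(messages=["sys", "hi", "hello", "how are you?"]))
```

In the implementation above, an omitted slice bound is rewritten to 0 and a non-positive right bound is taken relative to the array length, so messages[1:] behaves like messages[1:len(messages)].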
diff --git a/tools/fastllm_pytools/hf_model.py b/tools/fastllm_pytools/hf_model.py
index 8761dfee..ed646565 100644
--- a/tools/fastllm_pytools/hf_model.py
+++ b/tools/fastllm_pytools/hf_model.py
@@ -101,8 +101,8 @@ def create(model,
modelInfo["tokenizer_class"] = tokenizer.name;
if "rope_scaling" in modelInfo and isinstance(modelInfo["rope_scaling"], builtins.dict):
rope_scaling = modelInfo.pop("rope_scaling")
- modelInfo["rope_scaling.type"] = rope_scaling["type"]
- modelInfo["rope_scaling.factor"] = rope_scaling["factor"]
+ for key, value in rope_scaling.items():
+ modelInfo["rope_scaling." + key] = value
if eos_id:
modelInfo["eos_token_id"] = str(eos_id)
diff --git a/tools/fastllm_pytools/torch2flm.py b/tools/fastllm_pytools/torch2flm.py
index 523fe6b5..7b430970 100644
--- a/tools/fastllm_pytools/torch2flm.py
+++ b/tools/fastllm_pytools/torch2flm.py
@@ -186,8 +186,8 @@ def tofile(exportPath,
modelInfo["tokenizer_class"] = tokenizer.name;
if "rope_scaling" in modelInfo and isinstance(modelInfo["rope_scaling"], builtins.dict):
rope_scaling = modelInfo.pop("rope_scaling")
- modelInfo["rope_scaling.type"] = rope_scaling["type"]
- modelInfo["rope_scaling.factor"] = rope_scaling["factor"]
+ for key, value in rope_scaling.items():
+ modelInfo["rope_scaling." + key] = value
if eos_id:
modelInfo["eos_token_id"] = str(eos_id)
diff --git a/tools/scripts/chatglm_export.py b/tools/scripts/chatglm_export.py
index 2be62d37..0657e85b 100644
--- a/tools/scripts/chatglm_export.py
+++ b/tools/scripts/chatglm_export.py
@@ -3,8 +3,9 @@
from ftllm import torch2flm
if __name__ == "__main__":
- tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
- model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
+ modelNameOrPath = sys.argv[3] if len(sys.argv) >= 4 else 'THUDM/chatglm2-6b'
+ tokenizer = AutoTokenizer.from_pretrained(modelNameOrPath, trust_remote_code=True)
+ model = AutoModel.from_pretrained(modelNameOrPath, trust_remote_code=True)
model = model.eval()
dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
diff --git a/tools/scripts/minicpm2flm.py b/tools/scripts/minicpm2flm.py
index ae08a2ec..76ff355f 100644
--- a/tools/scripts/minicpm2flm.py
+++ b/tools/scripts/minicpm2flm.py
@@ -10,10 +10,13 @@
model = AutoModelForCausalLM.from_pretrained(modelNameOrPath, trust_remote_code=True, torch_dtype=torch.float16)
model = model.eval()
- model.config.__dict__['model_type'] = 'minicpm'
-
dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
exportPath = sys.argv[1] if len(sys.argv) >= 2 else "minicpm-2b-" + dtype + ".flm"
- torch2flm.tofile(exportPath, model, tokenizer, pre_prompt = "",
- user_role = "<用户>", bot_role = "",
- history_sep = "", dtype = dtype)
+
+ if model.config.architectures == ["MiniCPMForCausalLM"]:
+ model.config.model_type = "minicpm"
+ torch2flm.tofile(exportPath, model, tokenizer, pre_prompt = "", user_role = "<用户>",
+ bot_role = "", history_sep = "", dtype = dtype)
+ else:
+ torch2flm.tofile(exportPath, model, tokenizer, pre_prompt="", user_role="<|im_start|>user\n",
+ bot_role="<|im_end|>\n<|im_start|>assistant\n", history_sep="<|im_end|>\n", eos_id = tokenizer.eos_token_id, dtype = dtype)