From 02de5cb85d4e493b7355d5588c573c957edc95fc Mon Sep 17 00:00:00 2001 From: cgli Date: Mon, 25 Nov 2024 21:39:05 +0800 Subject: [PATCH 1/5] =?UTF-8?q?=E4=BF=AE=E5=A4=8DWin32Demo=E7=BC=96?= =?UTF-8?q?=E8=AF=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../app/src/main/cpp/CMakeLists.txt | 12 +++- example/Win32Demo/fastllm-gpu.vcxproj | 19 +++++++ example/Win32Demo/fastllm-gpu.vcxproj.filters | 55 ++++++++++++++++++- example/Win32Demo/fastllm.vcxproj | 11 ++++ example/Win32Demo/fastllm.vcxproj.filters | 27 +++++++++ 5 files changed, 119 insertions(+), 5 deletions(-) diff --git a/example/Android/LLMAssistant/app/src/main/cpp/CMakeLists.txt b/example/Android/LLMAssistant/app/src/main/cpp/CMakeLists.txt index 9879ddc7..0b852915 100644 --- a/example/Android/LLMAssistant/app/src/main/cpp/CMakeLists.txt +++ b/example/Android/LLMAssistant/app/src/main/cpp/CMakeLists.txt @@ -33,17 +33,23 @@ set(PROJECT_SOURCE ../../../../../../../src/template.cpp ../../../../../../../src/devices/cpu/cpudevice.cpp ../../../../../../../src/devices/cpu/cpudevicebatch.cpp + ../../../../../../../src/devices/cpu/linear.cpp ../../../../../../../src/models/basellm.cpp ../../../../../../../src/models/bert.cpp ../../../../../../../src/models/chatglm.cpp + ../../../../../../../src/models/cogvlm.cpp ../../../../../../../src/models/deepseekv2.cpp - ../../../../../../../src/models/moss.cpp - ../../../../../../../src/models/llama.cpp - ../../../../../../../src/models/qwen.cpp ../../../../../../../src/models/glm.cpp + ../../../../../../../src/models/graphllm.cpp + ../../../../../../../src/models/llama.cpp ../../../../../../../src/models/internlm2.cpp ../../../../../../../src/models/minicpm.cpp + ../../../../../../../src/models/minicpm3.cpp ../../../../../../../src/models/moe.cpp + ../../../../../../../src/models/moss.cpp + ../../../../../../../src/models/phi3.cpp + ../../../../../../../src/models/qwen.cpp + ../../../../../../../src/models/xlmroberta.cpp ) include_directories( diff --git a/example/Win32Demo/fastllm-gpu.vcxproj b/example/Win32Demo/fastllm-gpu.vcxproj index 8afd2f46..f0145dc9 100644 --- a/example/Win32Demo/fastllm-gpu.vcxproj +++ b/example/Win32Demo/fastllm-gpu.vcxproj @@ -196,9 +196,11 @@ + + @@ -206,6 +208,7 @@ + @@ -213,9 +216,12 @@ + + + @@ -225,8 +231,10 @@ + + @@ -234,18 +242,23 @@ + + + + + @@ -253,9 +266,15 @@ Document + + Document + Document + + Document + \ No newline at end of file diff --git a/example/Win32Demo/fastllm-gpu.vcxproj.filters b/example/Win32Demo/fastllm-gpu.vcxproj.filters index ebea6d5b..d3768f17 100644 --- a/example/Win32Demo/fastllm-gpu.vcxproj.filters +++ b/example/Win32Demo/fastllm-gpu.vcxproj.filters @@ -49,6 +49,12 @@ {385ba64f-d978-469c-af0c-498ec33a1bb7} + + {917a368b-3a0e-476b-9f35-6af4433aca16} + + + {9fbf8989-df11-4dc7-9697-d69d10b7f0bf} + @@ -78,6 +84,9 @@ 头文件\models + + 头文件\models + 头文件\models @@ -99,21 +108,33 @@ 头文件\models + + 头文件\models + 头文件\models 头文件\models + + 头文件\models + 头文件\models - - 头文件\devices\cpu + + 头文件\models 头文件\devices\cpu + + 头文件\devices\cpu + + + 头文件\devices\cpu + 头文件\devices\cpu @@ -129,6 +150,12 @@ 头文件\devices\cuda + + 头文件\devices\multicuda + + + 头文件\devices\multicuda + 头文件\third_party @@ -161,6 +188,9 @@ 源文件\models + + 源文件\models + 源文件\models @@ -179,18 +209,30 @@ 源文件\models + + 源文件\models + 源文件\models 源文件\models + + 源文件\models + 源文件\models + + 源文件\models + 源文件\models\graph + + 源文件\models\graph + 源文件\models\graph @@ -203,12 +245,18 @@ 源文件\devices\cpu + + 源文件\devices\cpu + 源文件\devices\cuda 源文件\devices\cuda + + 源文件\devices\multicuda + 源文件\third_party @@ -217,5 +265,8 @@ 源文件\devices\cuda + + 源文件\devices\multicuda + \ No newline at end of file diff --git a/example/Win32Demo/fastllm.vcxproj b/example/Win32Demo/fastllm.vcxproj index 421b94fd..5ee39467 100644 --- a/example/Win32Demo/fastllm.vcxproj +++ b/example/Win32Demo/fastllm.vcxproj @@ -173,6 +173,7 @@ + @@ -182,6 +183,7 @@ + @@ -189,9 +191,12 @@ + + + @@ -201,6 +206,7 @@ + @@ -208,18 +214,23 @@ + + + + + diff --git a/example/Win32Demo/fastllm.vcxproj.filters b/example/Win32Demo/fastllm.vcxproj.filters index 54497dd0..17387b28 100644 --- a/example/Win32Demo/fastllm.vcxproj.filters +++ b/example/Win32Demo/fastllm.vcxproj.filters @@ -78,6 +78,9 @@ 头文件\models + + 头文件\models + 头文件\models @@ -99,15 +102,24 @@ 头文件\models + + 头文件\models + 头文件\models 头文件\models + + 头文件\models + 头文件\models + + 头文件\models + 头文件\devices\cpu @@ -155,6 +167,9 @@ 源文件\models + + 源文件\models + 源文件\models @@ -173,18 +188,30 @@ 源文件\models + + 源文件\models + 源文件\models 源文件\models + + 源文件\models + 源文件\models + + 源文件\models + 源文件\models\graph + + 源文件\models\graph + 源文件\models\graph From e6d4e563215ed639f0d9466766e12abd4ea3dac8 Mon Sep 17 00:00:00 2001 From: cgli Date: Sun, 1 Dec 2024 12:43:53 +0800 Subject: [PATCH 2/5] =?UTF-8?q?=E6=94=AF=E6=8C=81MiniCPM3-4B=E5=AF=BC?= =?UTF-8?q?=E5=87=BA=E5=B9=B6=E4=BD=BF=E7=94=A8C++=20Tokenizer?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/models.md | 6 +++++- src/models/minicpm3.cpp | 2 ++ tools/fastllm_pytools/hf_model.py | 4 ++-- tools/fastllm_pytools/torch2flm.py | 4 ++-- tools/scripts/chatglm_export.py | 5 +++-- tools/scripts/minicpm2flm.py | 13 ++++++++----- 6 files changed, 22 insertions(+), 12 deletions(-) diff --git a/docs/models.md b/docs/models.md index bafa09a9..17929b40 100644 --- a/docs/models.md +++ b/docs/models.md @@ -130,6 +130,9 @@ | | | | | | openbmb/MiniCPM-2B-sft-fp16 | [✔](#其它模型) | [✔](#minicpm模型导出) | | | openbmb/MiniCPM-2B-dpo-fp16 | [✔](#其它模型) | [✔](#minicpm模型导出) | | +| openbmb/MiniCPM3-4B | [✔](#其它模型) | [✔](#minicpm模型导出) | | +| | | | | +| microsoft/Phi-3-mini-4k-instruct | | | ✔ | ### 加载后转换(两行加速模式)(convert on-the-fly) @@ -265,6 +268,7 @@ python3 tools/llamalike2flm.py qwen1.5-7b-int4.flm int4 "qwen/Qwen1.5-14B-Chat" # 需要先安装MiniCPM环境(transformers >= 4.36.0) # 默认脚本导出iniCPM-2B-dpo-fp16模型 cd build -python tools/minicpm2flm.py minicpm-2b-float16.flm #导出dpo-float16模型 +python tools/minicpm2flm.py minicpm-2b-fp16.flm #导出dpo-float16模型 +python tools/minicpm2flm.py minicpm3-4b-fp16.flm openbmb/MiniCPM3-4B #导出minicpm3-float16模型 ./main -p minicpm-2b-float16.flm # 执行模型 ``` \ No newline at end of file diff --git a/src/models/minicpm3.cpp b/src/models/minicpm3.cpp index ef92c433..8396710f 100644 --- a/src/models/minicpm3.cpp +++ b/src/models/minicpm3.cpp @@ -80,6 +80,8 @@ namespace fastllm { if (this->weight.dicts.find("kv_lora_rank") != this->weight.dicts.end()) { this->kv_lora_rank = std::stoi(this->weight.dicts["kv_lora_rank"]); } + weight.tokenizer.SetSpecialTokens({{"", 2}, {"", 1}, {"", 0}, {"<|im_start|>", 73441}, {"<|im_end|>", 73440}, {"<|tool_call|>", 73442}, + {"<|execute_start|>", 73443}, {"<|execute_end|>", 73444}, {"<|fim_prefix|>", 73445}, {"<|fim_middle|>", 73446}, {"<|fim_suffix|>", 73447}}); } int MiniCpm3Model::Forward(const fastllm::Data &inputIds, const fastllm::Data &attentionMask, diff --git a/tools/fastllm_pytools/hf_model.py b/tools/fastllm_pytools/hf_model.py index 8761dfee..ed646565 100644 --- a/tools/fastllm_pytools/hf_model.py +++ b/tools/fastllm_pytools/hf_model.py @@ -101,8 +101,8 @@ def create(model, modelInfo["tokenizer_class"] = tokenizer.name; if "rope_scaling" in modelInfo and isinstance(modelInfo["rope_scaling"], builtins.dict): rope_scaling = modelInfo.pop("rope_scaling") - modelInfo["rope_scaling.type"] = rope_scaling["type"] - modelInfo["rope_scaling.factor"] = rope_scaling["factor"] + for key, value in rope_scaling.items(): + modelInfo["rope_scaling." + key] = value if eos_id: modelInfo["eos_token_id"] = str(eos_id) diff --git a/tools/fastllm_pytools/torch2flm.py b/tools/fastllm_pytools/torch2flm.py index 523fe6b5..7b430970 100644 --- a/tools/fastllm_pytools/torch2flm.py +++ b/tools/fastllm_pytools/torch2flm.py @@ -186,8 +186,8 @@ def tofile(exportPath, modelInfo["tokenizer_class"] = tokenizer.name; if "rope_scaling" in modelInfo and isinstance(modelInfo["rope_scaling"], builtins.dict): rope_scaling = modelInfo.pop("rope_scaling") - modelInfo["rope_scaling.type"] = rope_scaling["type"] - modelInfo["rope_scaling.factor"] = rope_scaling["factor"] + for key, value in rope_scaling.items(): + modelInfo["rope_scaling." + key] = value if eos_id: modelInfo["eos_token_id"] = str(eos_id) diff --git a/tools/scripts/chatglm_export.py b/tools/scripts/chatglm_export.py index 2be62d37..0657e85b 100644 --- a/tools/scripts/chatglm_export.py +++ b/tools/scripts/chatglm_export.py @@ -3,8 +3,9 @@ from ftllm import torch2flm if __name__ == "__main__": - tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True) - model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True) + modelNameOrPath = sys.argv[3] if len(sys.argv) >= 4 else 'THUDM/chatglm2-6b' + tokenizer = AutoTokenizer.from_pretrained(modelNameOrPath, trust_remote_code=True) + model = AutoModel.from_pretrained(modelNameOrPath, trust_remote_code=True) model = model.eval() dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16" diff --git a/tools/scripts/minicpm2flm.py b/tools/scripts/minicpm2flm.py index ae08a2ec..76ff355f 100644 --- a/tools/scripts/minicpm2flm.py +++ b/tools/scripts/minicpm2flm.py @@ -10,10 +10,13 @@ model = AutoModelForCausalLM.from_pretrained(modelNameOrPath, trust_remote_code=True, torch_dtype=torch.float16) model = model.eval() - model.config.__dict__['model_type'] = 'minicpm' - dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16" exportPath = sys.argv[1] if len(sys.argv) >= 2 else "minicpm-2b-" + dtype + ".flm" - torch2flm.tofile(exportPath, model, tokenizer, pre_prompt = "", - user_role = "<用户>", bot_role = "", - history_sep = "", dtype = dtype) + + if model.config.architectures == ["MiniCPMForCausalLM"]: + model.config.model_type = "minicpm" + torch2flm.tofile(exportPath, model, tokenizer, pre_prompt = "", user_role = "<用户>", + bot_role = "", history_sep = "", dtype = dtype) + else: + torch2flm.tofile(exportPath, model, tokenizer, pre_prompt="", user_role="<|im_start|>user\n", + bot_role="<|im_end|>\n<|im_start|>assistant\n", history_sep="<|im_end|>\n", eos_id = tokenizer.eos_token_id, dtype = dtype) From 8c7eb27c5345280c67c2d0890560f4f2b0d75291 Mon Sep 17 00:00:00 2001 From: cgli Date: Sun, 1 Dec 2024 21:45:51 +0800 Subject: [PATCH 3/5] =?UTF-8?q?=E6=94=AF=E6=8C=81=E7=9B=B4=E6=8E=A5?= =?UTF-8?q?=E5=8A=A0=E8=BD=BDQwen=EF=BC=88=E4=B8=80=E4=BB=A3=EF=BC=89?= =?UTF-8?q?=E7=9A=84HF=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/models.md | 8 ++++---- src/model.cpp | 14 +++++++++++++- src/models/qwen.cpp | 5 +++++ 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/docs/models.md b/docs/models.md index 17929b40..8860da77 100644 --- a/docs/models.md +++ b/docs/models.md @@ -50,10 +50,10 @@ | 模型 | 加载后转换 | 离线转换 | 直接读取 | |-------------------: |------------|------------|------------| -| Qwen/Qwen-7B-Chat | [✔](#其它模型) | [✔](#qwen模型导出) | | -| Qwen/Qwen-14B-Chat | [✔](#其它模型) | [✔](#qwen模型导出) | | -| Qwen/Qwen-72B-Chat | [✔](#其它模型) | [✔](#qwen模型导出) | | -| Qwen/Qwen-1_8B-Chat | [✔](#其它模型) | [✔](#qwen模型导出) | | +| Qwen/Qwen-7B-Chat | [✔](#其它模型) | [✔](#qwen模型导出) | ✔ | +| Qwen/Qwen-14B-Chat | [✔](#其它模型) | [✔](#qwen模型导出) | ✔ | +| Qwen/Qwen-72B-Chat | [✔](#其它模型) | [✔](#qwen模型导出) | √ | +| Qwen/Qwen-1_8B-Chat | [✔](#其它模型) | [✔](#qwen模型导出) | ✔ | | Qwen/Qwen1.5-0.5B-Chat | [✔](#其它模型) | [✔](#qwen模型导出) | ✔3 | | Qwen/Qwen1.5-1.8B-Chat | [✔](#其它模型) | [✔](#qwen模型导出) | ✔3 | | Qwen/Qwen1.5-4B-Chat | [✔](#其它模型) | [✔](#qwen模型导出) | ✔3 | diff --git a/src/model.cpp b/src/model.cpp index 2314ba7e..c118ccee 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -524,6 +524,18 @@ namespace fastllm { model->history_sep = ""; model->weight.tokenizer.type = Tokenizer::TokenizerType::QWEN; model->weight.tokenizer.chatTemplate = ""; + } else if (tokenizerClass == "QWenTokenizer") { + // Qwen用的分词 + std::vector lines, line; + SplitString(ReadAllFile(path + "qwen.tiktoken"), {'\n'}, lines); + for (int i = 0; i < lines.size(); i++) { + SplitString(lines[i], {' '}, line); + model->weight.AddTokenizerWord(Base64Decode(line[0]), atoi(line[1].c_str()), 1.0f); + } + model->weight.tokenizer.type = Tokenizer::TokenizerType::QWEN; + model->weight.tokenizer.chatTemplate = ""; + model->weight.dicts["im_end_id"] = std::to_string(lines.size() + 1); + model->weight.dicts["im_start_id"] = std::to_string(lines.size() + 2); } else { ErrorInFastLLM("Unsupport tokenizer_class: " + tokenizerClass); } @@ -637,7 +649,7 @@ namespace fastllm { for (auto &it : generation_config.object_items()) { if ("eos_token_id" == it.first && it.second.type() == json11::Json::ARRAY) continue; - model->weight.AddDict(it.first, it.second.dump().c_str()); + model->weight.AddDict(it.first, it.second.is_string() ? it.second.string_value() : it.second.dump()); } // 更新eos_token_id if (generation_config["eos_token_id"].is_array()) { diff --git a/src/models/qwen.cpp b/src/models/qwen.cpp index 1c1ddfaa..abad7c33 100644 --- a/src/models/qwen.cpp +++ b/src/models/qwen.cpp @@ -56,6 +56,11 @@ namespace fastllm { } weight.embeddingNames.insert("transformer.wte.weight"); + weight.linearNames = { + "lm_head.weight", "transformer.h.*.ln_1.weight", "transformer.h.*.attn.c_attn.weight", + "transformer.h.*.attn.c_proj.weight", "transformer.h.*.ln_2.weight", + "transformer.h.*.mlp.w1.weight", "transformer.h.*.mlp.w2.weight", "transformer.h.*.mlp.c_proj.weight" + }; } int QWenModel::Forward(const Data &inputIds, From 3b2220dd861d2b3cea82e5faf9fdf0683efcd1fe Mon Sep 17 00:00:00 2001 From: cgli Date: Sat, 27 Apr 2024 16:58:36 +0800 Subject: [PATCH 4/5] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dqwen=E6=89=B9=E9=87=8F?= =?UTF-8?q?=E6=8E=A8=E7=90=86=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/models/qwen.cpp | 80 +++++++++++++++++++++++++++++---------------- 1 file changed, 52 insertions(+), 28 deletions(-) diff --git a/src/models/qwen.cpp b/src/models/qwen.cpp index abad7c33..662b9b54 100644 --- a/src/models/qwen.cpp +++ b/src/models/qwen.cpp @@ -171,12 +171,13 @@ namespace fastllm { // Attention MatMulTransB(query, pastKey, attnWeights, 1.0 / sqrt(head_dim)); - attnWeights.Reshape({1, attnWeights.dims[0], attnWeights.dims[1], attnWeights.dims[2]}); + attnWeights.Reshape({batch, -1, attnWeights.dims[1], attnWeights.dims[2]}); if (!attentionMask.dims.empty()) { AttentionMask(attnWeights, attentionMask, -10000); } Softmax(attnWeights, attnWeights, -1); + attnWeights.Reshape({1, -1, attnWeights.dims[2], attnWeights.dims[3]}); MatMul(attnWeights, pastValue, attnOutput); attnOutput.Reshape({attnOutput.dims[1], attnOutput.dims[2], attnOutput.dims[3]}); @@ -465,45 +466,68 @@ namespace fastllm { Data &inputIds, Data &attentionMask, Data &positionIds) { int batch = inputTokens.size(); int index = params[0].find("index")->second; - int promptLen = params[0].find("promptLen")->second; inputIds.ToDevice(DataDevice::CPU); attentionMask.ToDevice(DataDevice::CPU); positionIds.ToDevice(DataDevice::CPU); + std::vector seqLens; + seqLens.resize(batch); + int maxLen = 0; + for (int i = 0; i < batch; i++) { + int promptLen = params[i].find("promptLen")->second + index; + maxLen = std::max(promptLen, maxLen); + seqLens[i] = promptLen; + } + if (index == 0) { int seqLen = inputTokens[0].size(); - std::vector ids = std::vector(batch * seqLen, 0); - std::vector vmask = std::vector (batch * seqLen * seqLen, 0); - std::vector vpids = std::vector(batch * seqLen, 0); - for (int b = 0; b < batch; b++) { - for (int i = 0; i < seqLen; i++) { - ids[b * seqLen + i] = inputTokens[b][i]; + std::vector ids = std::vector (batch * maxLen, 0); + std::vector vpids = std::vector (batch * maxLen, 0); + std::vector vmask = std::vector (batch * maxLen * maxLen, 0); + for (int i = 0; i < batch; i++) { + auto &tokens = inputTokens[i]; + int len = tokens.size(), base = maxLen - len; + for (int j = 0; j < len; j++) { + ids[i * maxLen + base + j] = tokens[j]; } - } - for (int i = 0; i < seqLen; i++) { - vpids[i] = i; - for (int j = i + 1; j < seqLen; j++) { - vmask[i * seqLen + j] = 1; + for (int j = 0; j < len; j++) { + vpids[i * maxLen + base + j] = j; + } + + std::fill(vmask.data() + i * maxLen * maxLen, + vmask.data() + i * maxLen * maxLen + (maxLen - len) * maxLen, 1.0); + for (int j = maxLen - len; j < maxLen; j++) { + std::fill(vmask.data() + i * maxLen * maxLen + j * maxLen, + vmask.data() + i * maxLen * maxLen + j * maxLen + maxLen - len, 1.0); + } + for (int j = 0; j < len; j++) { + for (int k = j + 1; k < len; k++) { + vmask[i * maxLen * maxLen + (base + j) * maxLen + base + k] = 1; + } } } - for (int b = 1; b < batch; b++) { - memcpy(vmask.data() + b * seqLen * seqLen, vmask.data(), seqLen * seqLen * sizeof(float)); - memcpy(vpids.data() + b * seqLen, vpids.data(), seqLen * sizeof(float)); - } - inputIds.CopyFrom(Data(DataType::FLOAT32, {batch, seqLen}, ids)); - attentionMask.CopyFrom(Data(DataType::FLOAT32, {batch, seqLen, seqLen}, vmask)); - positionIds.CopyFrom(Data(DataType::FLOAT32, {batch, seqLen}, vpids)); + + inputIds.CopyFrom(Data(DataType::FLOAT32, {batch, maxLen}, ids)); + attentionMask.CopyFrom(Data(DataType::FLOAT32, {batch, maxLen, maxLen}, vmask)); + positionIds.CopyFrom(Data(DataType::FLOAT32, {batch, maxLen}, vpids)); } else { - std::vector ids = std::vector(batch * 1, 0); - std::vector vpids = std::vector(batch * 1, 0); - for (int b = 0; b < batch; b++) { - ids[b] = inputTokens[b][0]; - vpids[b] = (float) (promptLen + index - 1); + maxLen++; + std::vector fret; + for (int i = 0; i < batch; i++) { + fret.push_back(inputTokens[i][0]); } - inputIds.CopyFrom(Data(DataType::FLOAT32, {batch, 1}, ids)); - attentionMask.CopyFrom(Data()); - positionIds.CopyFrom(Data(DataType::FLOAT32, {batch, 1}, vpids)); + std::vector pids = std::vector(batch); + std::vector vmasks = std::vector(batch * maxLen, 0.0f); + for (int i = 0; i < batch; i++) { + pids[i] = seqLens[i] - 1; + for (int j = 0; j < maxLen - seqLens[i] - 1; j++) { + vmasks[i * maxLen + j] = 1.0f; + } + } + inputIds.CopyFrom(Data(DataType::FLOAT32, {batch, 1}, fret)); + attentionMask.CopyFrom(Data(DataType::FLOAT32, {batch, 1, maxLen}, vmasks)); + positionIds.CopyFrom(Data(DataType::FLOAT32, {batch, 1}, pids)); } } From 1650218315b0599803aa1fd22b28ea419114a04a Mon Sep 17 00:00:00 2001 From: cgli Date: Thu, 5 Dec 2024 18:51:46 +0800 Subject: [PATCH 5/5] =?UTF-8?q?Jinja2=E6=A8=A1=E6=9D=BF=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=8F=96=E6=A8=A1=E8=BF=90=E7=AE=97=E5=92=8C=E6=95=B0=E7=BB=84?= =?UTF-8?q?=E5=88=87=E7=89=87=EF=BC=8C=E6=94=AF=E6=8C=81=E7=9B=B4=E6=8E=A5?= =?UTF-8?q?=E5=8A=A0=E8=BD=BDcodellama-7b-instruct=E7=9A=84HF=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/models.md | 4 ++-- include/template.h | 8 +++++--- src/model.cpp | 5 ++++- src/template.cpp | 51 ++++++++++++++++++++++++++++++++++++++-------- 4 files changed, 54 insertions(+), 14 deletions(-) diff --git a/docs/models.md b/docs/models.md index 8860da77..cb60ef78 100644 --- a/docs/models.md +++ b/docs/models.md @@ -99,8 +99,8 @@ |-----------------: |------------|------------|------------| | meta-llama/Llama-2-7b-chat-hf | [✔](llama_cookbook.md#llama2-chat) | [✔](llama_cookbook.md#llama2-chat) | | | meta-llama/Llama-2-13b-chat-hf | [✔](llama_cookbook.md#llama2-chat) | [✔](llama_cookbook.md#llama2-chat) | | -| codellama/CodeLlama-7b-Instruct-hf | [✔](llama_cookbook.md#llama2-chat) | [✔](llama_cookbook.md#llama2-chat) | | -| codellama/CodeLlama-13b-Instruct-hf | [✔](llama_cookbook.md#llama2-chat) | [✔](llama_cookbook.md#llama2-chat) | | +| codellama/CodeLlama-7b-Instruct-hf | [✔](llama_cookbook.md#llama2-chat) | [✔](llama_cookbook.md#llama2-chat) | ✔ | +| codellama/CodeLlama-13b-Instruct-hf | [✔](llama_cookbook.md#llama2-chat) | [✔](llama_cookbook.md#llama2-chat) | ✔ | | xverse/XVERSE-13B-Chat | [✔](llama_cookbook.md#xverse) | [✔](llama_cookbook.md#xverse) | | | xverse/XVERSE-7B-Chat | [✔](llama_cookbook.md#xverse) | [✔](llama_cookbook.md#xverse) | | | | | | | diff --git a/include/template.h b/include/template.h index e6eb6d76..8f03f11e 100644 --- a/include/template.h +++ b/include/template.h @@ -53,9 +53,9 @@ namespace fastllm { JinjaTokenLMB, JinjaTokenRMB, JinjaTokenLSB, JinjaTokenRSB, JinjaTokenSet, JinjaTokenFor, JinjaTokenEndFor, JinjaTokenIf, JinjaTokenElse, JinjaTokenElseIf, JinjaTokenEndif, JinjaTokenIn, - JinjaTokenAssign, JinjaTokenNotEqual, JinjaTokenEqual, JinjaTokenAdd, JinjaTokenSub, JinjaTokenMul, JinjaTokenDiv, + JinjaTokenAssign, JinjaTokenNotEqual, JinjaTokenEqual, JinjaTokenAdd, JinjaTokenSub, JinjaTokenMul, JinjaTokenDiv, JinjaTokenMod, JinjaTokenNot, JinjaTokenAnd, JinjaTokenOr, - JinjaTokenFliter, JinjaTokenNamespace + JinjaTokenFilter, JinjaTokenNamespace, JinjaTokenSlice }; JinjaToKenType type; @@ -74,7 +74,9 @@ namespace fastllm { {'-', JinjaToken::JinjaToKenType::JinjaTokenSub}, {'*', JinjaToken::JinjaToKenType::JinjaTokenMul}, {'/', JinjaToken::JinjaToKenType::JinjaTokenDiv}, - {'|', JinjaToken::JinjaToKenType::JinjaTokenFliter} + {'%', JinjaToken::JinjaToKenType::JinjaTokenMod}, + {'|', JinjaToken::JinjaToKenType::JinjaTokenFilter}, + {':', JinjaToken::JinjaToKenType::JinjaTokenSlice} }; static std::map escapeChars = { diff --git a/src/model.cpp b/src/model.cpp index c118ccee..68314906 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -479,10 +479,13 @@ namespace fastllm { if (tokenizerClass == "PreTrainedTokenizerFast" || tokenizerClass == "LlamaTokenizerFast" || tokenizerClass == "Qwen2Tokenizer" || tokenizerClass == "BloomTokenizer" - || tokenizerClass == "LlamaTokenizer" + || tokenizerClass == "LlamaTokenizer" || tokenizerClass == "CodeLlamaTokenizer" || tokenizerClass == "MiniCPMTokenizer") { // PreTrainedTokenizerFast std::string tokenizerFile = path + "tokenizer.json"; + if (!fastllm::FileExists(tokenizerFile)) { + ErrorInFastLLM("Model with a supported tokenizer_class: " + tokenizerClass + ",but has no \"tokenizer.json\"!"); + } auto tokenizer = json11::Json::parse(ReadAllFile(tokenizerFile), error); for (auto &it : tokenizer["model"]["vocab"].object_items()) { model->weight.AddTokenizerWord(it.first, it.second.int_value(), 1.0f); diff --git a/src/template.cpp b/src/template.cpp index 3ba85db7..ddacaa6f 100644 --- a/src/template.cpp +++ b/src/template.cpp @@ -221,6 +221,12 @@ namespace fastllm { } else if (op == JinjaToken::JinjaTokenAdd) { if (a.type == JinjaVar::JinjaString && b.type == JinjaVar::JinjaString) { return a.stringValue + b.stringValue; + } else if (a.type == JinjaVar::JinjaInt && b.type == JinjaVar::JinjaInt) { + return a.intValue + b.intValue; + } + } else if (op == JinjaToken::JinjaTokenMod) { + if (a.type == JinjaVar::JinjaInt && b.type == JinjaVar::JinjaInt) { + return a.intValue % b.intValue; } } else if (op == JinjaToken::JinjaTokenIn) { return b.dictValue.find(a.stringValue) != b.dictValue.end(); @@ -264,16 +270,18 @@ namespace fastllm { return -2; } else if (type == JinjaToken::JinjaTokenNot) { return -1; - } else if (type == JinjaToken::JinjaTokenEqual || type == JinjaToken::JinjaTokenNotEqual) { + } else if (type == JinjaToken::JinjaTokenEqual || type == JinjaToken::JinjaTokenNotEqual || type == JinjaToken::JinjaTokenIn) { return 0; } else if (type == JinjaToken::JinjaTokenAdd || type == JinjaToken::JinjaTokenSub) { return 1; - } else if (type == JinjaToken::JinjaTokenMul || type == JinjaToken::JinjaTokenDiv) { + } else if (type == JinjaToken::JinjaTokenMul || type == JinjaToken::JinjaTokenDiv || type == JinjaToken::JinjaTokenMod) { return 2; - } else if (type == JinjaToken::JinjaTokenFliter) { + } else if (type == JinjaToken::JinjaTokenFilter) { return 3; } else if (type == JinjaToken::JinjaTokenDOT) { return 4; + } else if (type == JinjaToken::JinjaTokenSlice) { + return 5; } else if (type == JinjaToken::JinjaTokenLSB || type == JinjaToken::JinjaTokenLMB) { return -5; } else { @@ -298,7 +306,15 @@ namespace fastllm { if (trimNext) part.erase(part.find_last_not_of(" \n\r\t") + 1); blocks.push_back(JinjaBlock(part)); - blocks.push_back(temp.substr(i, curEnd + 2 - i)); + // 处理切片语法糖 + part = temp.substr(i, curEnd + 2 - i); + size_t slicepos = part.find("[:"); + if (slicepos != std::string::npos) + part.replace(slicepos, 2, "[0:"); + slicepos = part.find(":]"); + if (slicepos != std::string::npos) + part.replace(slicepos, 2, ":0]"); + blocks.push_back(JinjaBlock(part)); trimNext = (temp[curEnd - 1] == '-'); pos = curEnd + 2; i = curEnd + 1; @@ -342,20 +358,23 @@ namespace fastllm { ops.pop_back(); } AssertInFastLLM(ops.size() > 0 && ops.back().type == JinjaToken::JinjaTokenLMB, "Error: barckets doesn't match."); - suffixExp.push_back(tokens[i]); + if (suffixExp.back().type != JinjaToken::JinjaTokenSlice) + suffixExp.push_back(tokens[i]); ops.pop_back(); } else if (tokens[i].type == JinjaToken::JinjaTokenDOT || tokens[i].type == JinjaToken::JinjaTokenAdd || tokens[i].type == JinjaToken::JinjaTokenSub || tokens[i].type == JinjaToken::JinjaTokenMul || tokens[i].type == JinjaToken::JinjaTokenDiv || + tokens[i].type == JinjaToken::JinjaTokenMod || tokens[i].type == JinjaToken::JinjaTokenEqual || tokens[i].type == JinjaToken::JinjaTokenNotEqual || + tokens[i].type == JinjaToken::JinjaTokenSlice || tokens[i].type == JinjaToken::JinjaTokenIn || tokens[i].type == JinjaToken::JinjaTokenAnd || tokens[i].type == JinjaToken::JinjaTokenOr || tokens[i].type == JinjaToken::JinjaTokenNot || - tokens[i].type == JinjaToken::JinjaTokenFliter) { + tokens[i].type == JinjaToken::JinjaTokenFilter) { while (ops.size() > 0 && GetOpLevel(ops.back().type) > GetOpLevel(tokens[i].type)) { suffixExp.push_back(ops.back()); ops.pop_back(); @@ -412,7 +431,7 @@ namespace fastllm { vars.pop_back(); vars.pop_back(); vars.push_back(JinjaVar({{a.stringValue, b}})); - } else if (it.type == JinjaToken::JinjaTokenFliter) { + } else if (it.type == JinjaToken::JinjaTokenFilter) { AssertInFastLLM(vars.size() > 1, "Jinja Error: expression error."); JinjaVar a = vars[vars.size() - 2], b = vars.back(); if (a.type == JinjaVar::JinjaNone) { @@ -437,6 +456,7 @@ namespace fastllm { it.type == JinjaToken::JinjaTokenSub || it.type == JinjaToken::JinjaTokenMul || it.type == JinjaToken::JinjaTokenDiv || + it.type == JinjaToken::JinjaTokenMod || it.type == JinjaToken::JinjaTokenAssign || it.type == JinjaToken::JinjaTokenEqual || it.type == JinjaToken::JinjaTokenNotEqual || @@ -445,7 +465,7 @@ namespace fastllm { it.type == JinjaToken::JinjaTokenOr) { AssertInFastLLM(vars.size() > 1, "Jinja Error: expression error."); JinjaVar a = vars[vars.size() - 2], b = vars.back(); - if (a.type == JinjaVar::JinjaNone) { + if (a.type == JinjaVar::JinjaNone && it.type != JinjaToken::JinjaTokenIn) { a = local[a]; } if (b.type == JinjaVar::JinjaNone) { @@ -454,6 +474,21 @@ namespace fastllm { vars.pop_back(); vars.pop_back(); vars.push_back(JinjaBinaryOp(a, b, it.type)); + } else if (it.type == JinjaToken::JinjaTokenSlice) { + AssertInFastLLM(vars.size() >= 3, "Jinja Error: expression error."); + JinjaVar a = vars[vars.size() - 3], b = vars[vars.size() - 2], c = vars.back(); + if (a.type == JinjaVar::JinjaNone) { + a = local[a]; + } + AssertInFastLLM(a.type == JinjaVar::JinjaArray && b.type == JinjaVar::JinjaInt && c.type == JinjaVar::JinjaInt, + "Jinja Error: slice expression error."); + vars.pop_back(); + vars.pop_back(); + vars.pop_back(); + if (c.intValue <= 0) + c.intValue += a.arrayValue.size(); + std::vector subArray(a.arrayValue.begin() + b.intValue, a.arrayValue.begin() + c.intValue); + vars.push_back(JinjaVar(subArray)); } }