From 5d8daf170c8d7113da1bfea434e3be3d3809d818 Mon Sep 17 00:00:00 2001
From: 黄宇扬
Date: Thu, 6 Jun 2024 07:40:19 +0800
Subject: [PATCH] Support chatglm4
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 include/models/chatglm.h |   4 ++
 src/fastllm.cpp          |  16 +++++-
 src/model.cpp            | 104 +++++++++++++++++++++++++++++++++++++--
 src/models/chatglm.cpp   |  28 ++++++++---
 4 files changed, 139 insertions(+), 13 deletions(-)

diff --git a/include/models/chatglm.h b/include/models/chatglm.h
index 1059b18a..4f9a03e2 100644
--- a/include/models/chatglm.h
+++ b/include/models/chatglm.h
@@ -69,10 +69,14 @@ namespace fastllm {
         void UpdateRotaryPosEmb(float rope_factor);
 
         int gmask_token_id;
+
+        std::string tokenizerClass = "";
 
     private:
         virtual void CausalMask(Data &data, int start) {}; // causal mask?
 
        float rope_factor = 1.0f;
+
+        float layernorm_epsilon = 1e-5;
     };
 }
diff --git a/src/fastllm.cpp b/src/fastllm.cpp
index 3f964c36..60b6e3d1 100644
--- a/src/fastllm.cpp
+++ b/src/fastllm.cpp
@@ -1289,7 +1289,21 @@ namespace fastllm {
             return Data (DataType::FLOAT32, {1, (int)v.size()}, v);
         } else if (this->type == TokenizerType::QWEN) {
             std::map <std::string, int> specialTokens = {{"<|im_start|>", 151644}, {"<|im_end|>", 151645}, {"<|endoftext|>", 151643}};
-
+            for (int i = 0; i < ori.size(); i++) {
+                if (i + 3 < ori.size() && ori[i] == '<' && ori[i + 1] == 'F' && ori[i + 2] == 'L' && ori[i + 3] == 'M') {
+                    if (i + 15 < ori.size() && ori.substr(i, 15) == "<FLM_FIX_TOKEN_") {
+                        i += 15;
+                        int now = 0;
+                        while (ori[i] >= '0' && ori[i] <= '9') {
+                            now = now * 10 + ori[i] - '0';
+                            i++;
+                        }
+                        specialTokens["<FLM_FIX_TOKEN_" + std::to_string(now) + ">"] = now;
+                        continue;
+                    }
+                }
+            }
+
             // comment these special tokens for now
             // for (int i = 0; i < 205; i++) {
             //     specialTokens.insert("<|extra_" + std::to_string(i) + "|>");
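The hunk above teaches the QWEN-type tokenizer a textual escape: any `<FLM_FIX_TOKEN_id>` substring in the input is registered as a special token that encodes to exactly `id`. A minimal standalone sketch of that scanning rule follows; `ParseFixTokens` is an illustrative name, not part of this patch:

#include <cstdio>
#include <map>
#include <string>

// Scan a prompt for <FLM_FIX_TOKEN_id> escapes and collect the id mappings.
std::map<std::string, int> ParseFixTokens(const std::string &ori) {
    std::map<std::string, int> specialTokens;
    const std::string prefix = "<FLM_FIX_TOKEN_";
    for (size_t i = 0; i < ori.size(); i++) {
        if (ori.compare(i, prefix.size(), prefix) == 0) {
            size_t j = i + prefix.size();
            int now = 0;
            while (j < ori.size() && ori[j] >= '0' && ori[j] <= '9') {
                now = now * 10 + (ori[j] - '0');  // accumulate the decimal token id
                j++;
            }
            // Register the literal escape so encoding maps it to the forced id.
            specialTokens[prefix + std::to_string(now) + ">"] = now;
            i = j;  // resume after the digits; the loop increment skips the '>'
        }
    }
    return specialTokens;
}

int main() {
    // "<FLM_FIX_TOKEN_151331>" forces token id 151331 into the encoding.
    for (auto &it : ParseFixTokens("<FLM_FIX_TOKEN_151331>\nhello")) {
        printf("%s -> %d\n", it.first.c_str(), it.second);
    }
    return 0;
}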
diff --git a/src/model.cpp b/src/model.cpp
index 7d260a3c..800ce4ea 100644
--- a/src/model.cpp
+++ b/src/model.cpp
@@ -305,6 +305,65 @@ namespace fastllm {
         }
     };
 
+    std::string Base64Decode(const std::string &encoded) {
+        static const std::string base64_chars =
+                "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+                "abcdefghijklmnopqrstuvwxyz"
+                "0123456789+/";
+        int in_len = encoded.size();
+        int i = 0, j = 0, in_ = 0;
+        char char_array_4[4], char_array_3[3];
+        std::string ret = "";
+
+        while (in_len-- && ( encoded[in_] != '=')) {
+            char_array_4[i++] = encoded[in_]; in_++;
+            if (i == 4) {
+                for (i = 0; i < 4; i++)
+                    char_array_4[i] = base64_chars.find(char_array_4[i]);
+                char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
+                char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
+                char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
+                for (i = 0; (i < 3); i++)
+                    ret.push_back(char_array_3[i]);
+                i = 0;
+            }
+        }
+
+        if (i) {
+            for (j = i; j < 4; j++)
+                char_array_4[j] = 0;
+
+            for (j = 0; j < 4; j++)
+                char_array_4[j] = base64_chars.find(char_array_4[j]);
+
+            char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
+            char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
+            char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
+
+            for (j = 0; (j < i - 1); j++) ret.push_back(char_array_3[j]);
+        }
+
+        return ret;
+    }
+
+    void SplitString(const std::string &str, const std::set <char> &chars, std::vector <std::string> &ret) {
+        ret.clear();
+        std::string now = "";
+        for (int i = 0; i < str.size(); i++) {
+            if (chars.find(str[i]) == chars.end()) {
+                now += str[i];
+            } else {
+                if (now != "") {
+                    ret.push_back(now);
+                    now = "";
+                }
+            }
+        }
+        if (now != "") {
+            ret.push_back(now);
+        }
+    }
+
     // Read from an HF folder; only safetensors-format models are supported
     std::unique_ptr <basellm> CreateLLMModelFromHF(const std::string &modelPath,
                                                    DataType linearDataType, int groupCnt) {
@@ -314,12 +373,16 @@ namespace fastllm {
         }
 
         // 1. Check for model.safetensors.index.json; read it if present
+        std::set <std::string> stFiles;
         std::string stIndexFile = path + "model.safetensors.index.json";
         std::string error;
-        auto stIndex = json11::Json::parse(ReadAllFile(stIndexFile), error)["weight_map"];
-        std::set <std::string> stFiles;
-        for (auto it : stIndex.object_items()) {
-            stFiles.insert(path + it.second.string_value());
+        if (access(stIndexFile.c_str(), R_OK) != 0) {
+            stFiles.insert(path + "model.safetensors");
+        } else {
+            auto stIndex = json11::Json::parse(ReadAllFile(stIndexFile), error)["weight_map"];
+            for (auto it : stIndex.object_items()) {
+                stFiles.insert(path + it.second.string_value());
+            }
         }
         SafeTensors safeTensors(stFiles);
 
@@ -355,6 +418,39 @@ namespace fastllm {
                 tokenizer["decoder"]["type"].string_value() == "ByteLevel") {
                 model->weight.tokenizer.byteAsChar = true;
             }
+        } else if (tokenizerClass == "ChatGLM4Tokenizer") {
+            // GLM4's dedicated tokenizer
+            model->bot_role = " ";
+            std::vector <std::string> lines, line;
+            SplitString(ReadAllFile(path + "tokenizer.model"), {'\r', '\n'}, lines);
+            for (int i = 0; i < lines.size(); i++) {
+                SplitString(lines[i], {' '}, line);
+                model->weight.AddTokenizerWord(Base64Decode(line[0]), atoi(line[1].c_str()), 1.0f);
+            }
+            std::map <std::string, int> spTokens;
+            for (auto &it : tokenizerConfig["added_tokens_decoder"].object_items()) {
+                spTokens[it.second["content"].string_value()] = atoi(it.first.c_str());
+            }
+            model->weight.tokenizer.SetSpecialTokens(spTokens);
+            ((ChatGLMModel*)model)->gmask_token_id = model->weight.tokenizer.GetTokenId("[gMASK]");
+            ((ChatGLMModel*)model)->bos_token_id = model->weight.tokenizer.GetTokenId("<sop>");
+            ((ChatGLMModel*)model)->tokenizerClass = tokenizerClass;
+
+            // Set eos_token_id
+            if (config["eos_token_id"].is_array()) {
+                for (auto &it : config["eos_token_id"].array_items()) {
+                    model->eos_token_ids.insert(it.int_value());
+                }
+            } else {
+                model->eos_token_id = config["eos_token_id"].int_value();
+            }
+
+            // ChatGLM builds prompts by concatenating tokens, so the separator token ids must be pinned explicitly
+            model->pre_prompt = "";
+            model->user_role = ("<FLM_FIX_TOKEN_" + std::to_string(model->weight.tokenizer.GetTokenId("<|user|>")) + ">\n");
+            model->bot_role = ("<FLM_FIX_TOKEN_" + std::to_string(model->weight.tokenizer.GetTokenId("<|assistant|>")) + ">");
+            model->history_sep = "";
+            model->weight.tokenizer.type = Tokenizer::TokenizerType::QWEN;
         } else {
             ErrorInFastLLM("Unsupport tokenizer_class: " + tokenizerClass);
         }
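For reference, GLM4's tokenizer.model here is plain text with one "<base64-encoded token> <token id>" pair per line; the loop above base64-decodes each token and registers it with its id. A small sketch of the same parse, assuming the Base64Decode and SplitString helpers added in the hunk above are in scope; LoadGlm4Vocab is an illustrative name:

#include <cstdlib>
#include <string>
#include <utility>
#include <vector>

// Sketch: parse GLM4's plain-text vocabulary. A line like "aGVsbG8= 1234"
// would register the token "hello" with id 1234 (pair chosen for illustration).
void LoadGlm4Vocab(const std::string &fileContent,
                   std::vector<std::pair<std::string, int>> &vocab) {
    std::vector<std::string> lines, line;
    SplitString(fileContent, {'\r', '\n'}, lines);  // one vocab entry per line
    for (size_t i = 0; i < lines.size(); i++) {
        SplitString(lines[i], {' '}, line);         // line[0] = base64 token, line[1] = id
        vocab.push_back({Base64Decode(line[0]), atoi(line[1].c_str())});
    }
}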
diff --git a/src/models/chatglm.cpp b/src/models/chatglm.cpp
index 175e1f03..94ec4739 100644
--- a/src/models/chatglm.cpp
+++ b/src/models/chatglm.cpp
@@ -69,6 +69,12 @@ namespace fastllm {
         this->UpdateRotaryPosEmb(1.0f);
         weight.embeddingNames.insert("transformer.word_embeddings.weight");
         weight.embeddingNames.insert("transformer.embedding.word_embeddings.weight");
+
+        weight.linearNames = {
+            "*.query_key_value.weight", "*.dense.weight",
+            "*.mlp.dense_h_to_4h.weight", "*.mlp.dense_4h_to_h.weight",
+            "lm_head.weight", "transformer.output_layer.weight"
+        };
     }
 
     void ChatGLMModel::InitParams() {
@@ -77,13 +83,19 @@ namespace fastllm {
         if (GetVersion() == 1) {
             if (this->weight.dicts.find("gmask_token_id") != this->weight.dicts.end()) {
                 this->gmask_token_id = atoi(this->weight.dicts["gmask_token_id"].c_str());
             }
-        } else if (GetVersion() == 2) {
+        } else if (GetVersion() == 2 && this->tokenizerClass != "ChatGLM4Tokenizer") {
             this->gmask_token_id = 64790;
             this->bos_token_id = 64792;
         }
-        if (this->weight.dicts.find("rope_ratio") != this->weight.dicts.end()) {
-            UpdateRotaryPosEmb(atof(this->weight.dicts["rope_ratio"].c_str()));
-        }
+        if (this->weight.dicts.find("rope_ratio") != this->weight.dicts.end()) {
+            UpdateRotaryPosEmb(atof(this->weight.dicts["rope_ratio"].c_str()));
+        }
+        if (this->weight.dicts.find("layernorm_epsilon") != this->weight.dicts.end()) {
+            this->layernorm_epsilon = atof(this->weight.dicts["layernorm_epsilon"].c_str());
+        }
+        if (this->weight.dicts.find("seq_length") != this->weight.dicts.end()) {
+            max_positions = atoi(this->weight.dicts["seq_length"].c_str());
+        }
     }
 
@@ -143,7 +155,7 @@ namespace fastllm {
             } else if (version == 2) {
                 std::string inputRMSWeightName =
                         "transformer.encoder.layers." + std::to_string(i) + ".input_layernorm.weight";
-                RMSNorm(hiddenStates, weight[inputRMSWeightName], 1e-5, attenInput);
+                RMSNorm(hiddenStates, weight[inputRMSWeightName], layernorm_epsilon, attenInput);
             }
             std::string qkvWeightName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.weight";
             std::string qkvBiasName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.bias";
@@ -291,7 +303,7 @@ namespace fastllm {
             std::string postRMSWeightName =
                     "transformer.encoder.layers." + std::to_string(i) + ".post_attention_layernorm.weight";
             Mul(hiddenStates, 1.0, temp);
-            RMSNorm(hiddenStates, weight[postRMSWeightName], 1e-5, mlpInput);
+            RMSNorm(hiddenStates, weight[postRMSWeightName], this->layernorm_epsilon, mlpInput);
             // 1.4 MLP
             std::string fcInKeyName = "transformer.encoder.layers." + std::to_string(i) + ".mlp.dense_h_to_4h";
             std::string fcOutKeyName = "transformer.encoder.layers." + std::to_string(i) + ".mlp.dense_4h_to_h";
@@ -325,7 +337,7 @@ namespace fastllm {
                       weight["transformer.final_layernorm.bias"], -1, hiddenStates);
             Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
         } else {
-            RMSNorm(hiddenStates, weight["transformer.encoder.final_layernorm.weight"], 1e-5, hiddenStates);
+            RMSNorm(hiddenStates, weight["transformer.encoder.final_layernorm.weight"], this->layernorm_epsilon, hiddenStates);
             Linear(hiddenStates, weight["transformer.output_layer.weight"], Data(), logits);
         }
 
@@ -434,7 +446,7 @@ namespace fastllm {
             } else if (version == 2) {
                 std::string inputRMSWeightName =
                         "transformer.encoder.layers." + std::to_string(i) + ".input_layernorm.weight";
-                RMSNorm(hiddenStates, weight[inputRMSWeightName], 1e-5, attenInput);
+                RMSNorm(hiddenStates, weight[inputRMSWeightName], this->layernorm_epsilon, attenInput);
             }
 
             std::string qkvWeightName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.weight";
@@ -690,7 +702,7 @@ namespace fastllm {
                     "transformer.encoder.layers." + std::to_string(i) + ".post_attention_layernorm.weight";
             Data temp;
             Mul(hiddenStates, 1.0, temp);
-            RMSNorm(hiddenStates, weight[postRMSWeightName], 1e-5, mlpInput);
+            RMSNorm(hiddenStates, weight[postRMSWeightName], this->layernorm_epsilon, mlpInput);
             // 1.4 MLP
             std::string fcInKeyName = "transformer.encoder.layers." + std::to_string(i) + ".mlp.dense_h_to_4h";
             std::string fcOutKeyName = "transformer.encoder.layers." + std::to_string(i) + ".mlp.dense_4h_to_h";
@@ -711,7 +723,7 @@ namespace fastllm {
                       weight["transformer.final_layernorm.bias"], -1, hiddenStates);
             Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
         } else {
-            RMSNorm(hiddenStates, weight["transformer.encoder.final_layernorm.weight"], 1e-5, hiddenStates);
+            RMSNorm(hiddenStates, weight["transformer.encoder.final_layernorm.weight"], this->layernorm_epsilon, hiddenStates);
             Linear(hiddenStates, weight["transformer.output_layer.weight"], Data(), logits);
         }
         ToDataType(logits, DataType::FLOAT32);
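Because ChatGLM4 prompts are assembled by plain string concatenation, the role separators are expressed with the `<FLM_FIX_TOKEN_id>` escape so they survive tokenization as single tokens. A sketch of the resulting prompt assembly; the ids below are illustrative placeholders, since the patch derives the real values from GetTokenId("<|user|>") and GetTokenId("<|assistant|>") at load time:

#include <cstdio>
#include <string>

int main() {
    int userId = 151336, assistantId = 151337;  // placeholder ids, not authoritative
    std::string user_role = "<FLM_FIX_TOKEN_" + std::to_string(userId) + ">\n";
    std::string bot_role = "<FLM_FIX_TOKEN_" + std::to_string(assistantId) + ">";
    // pre_prompt and history_sep are empty, so a single-turn prompt is just:
    std::string prompt = user_role + "Hello" + bot_role;
    // The QWEN-type encoder later replaces each escape with its fixed token id,
    // so the role separators reach the model as exactly one token each.
    printf("%s\n", prompt.c_str());
    return 0;
}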