Commit

Support chatglm4
黄宇扬 committed Jun 5, 2024
1 parent 1475ab4 commit 5d8daf1
Showing 4 changed files with 139 additions and 13 deletions.
4 changes: 4 additions & 0 deletions include/models/chatglm.h
@@ -69,10 +69,14 @@ namespace fastllm {
void UpdateRotaryPosEmb(float rope_factor);

int gmask_token_id;

std::string tokenizerClass = "";
private:
virtual void CausalMask(Data &data, int start) {}; // causal mask?

float rope_factor = 1.0f;

float layernorm_epsilon = 1e-5;
};
}

16 changes: 15 additions & 1 deletion src/fastllm.cpp
@@ -1289,7 +1289,21 @@ namespace fastllm {
return Data (DataType::FLOAT32, {1, (int)v.size()}, v);
} else if (this->type == TokenizerType::QWEN) {
std::map<std::string, int> specialTokens = {{"<|im_start|>", 151644}, {"<|im_end|>", 151645}, {"<|endoftext|>", 151643}};

for (int i = 0; i < ori.size(); i++) {
if (i + 3 < ori.size() && ori[i] == '<' && ori[i + 1] == 'F' && ori[i + 2] == 'L' && ori[i + 3] == 'M') {
if (i + 15 < ori.size() && ori.substr(i, 15) == "<FLM_FIX_TOKEN_") {
i += 15;
int now = 0;
while (ori[i] >= '0' && ori[i] <= '9') {
now = now * 10 + ori[i] - '0';
i++;
}
specialTokens["<FLM_FIX_TOKEN_" + std::to_string(now) + ">"] = now;
continue;
}
}
}

// comment these special tokens for now
// for (int i = 0; i < 205; i++) {
// specialTokens.insert("<|extra_" + std::to_string(i) + "|>");
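For context, here is a minimal standalone sketch of the `<FLM_FIX_TOKEN_n>` scan added above. The function name ExtractFixTokens is illustrative, not part of fastllm, and the 4-character `<FLM` fast path is omitted for brevity:

#include <iostream>
#include <map>
#include <string>

std::map<std::string, int> ExtractFixTokens(const std::string &ori) {
    std::map<std::string, int> specialTokens;
    for (size_t i = 0; i < ori.size(); i++) {
        if (i + 15 < ori.size() && ori.substr(i, 15) == "<FLM_FIX_TOKEN_") {
            i += 15;
            int now = 0;
            while (i < ori.size() && ori[i] >= '0' && ori[i] <= '9') {
                now = now * 10 + (ori[i] - '0');
                i++;
            }
            // Register the literal escape so tokenization maps it straight to
            // id `now`; the for-loop's i++ then steps past the closing '>'.
            specialTokens["<FLM_FIX_TOKEN_" + std::to_string(now) + ">"] = now;
        }
    }
    return specialTokens;
}

int main() {
    // Hypothetical ids; real ones come from the model's tokenizer.
    for (auto &kv : ExtractFixTokens("<FLM_FIX_TOKEN_64795>hi<FLM_FIX_TOKEN_64796>"))
        std::cout << kv.first << " -> " << kv.second << "\n";
}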
104 changes: 100 additions & 4 deletions src/model.cpp
@@ -305,6 +305,65 @@ namespace fastllm {
}
};

std::string Base64Decode(const std::string &encoded) {
static const std::string base64_chars =
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789+/";
int in_len = encoded.size();
int i = 0, j = 0, in_ = 0;
char char_array_4[4], char_array_3[3];
std::string ret = "";

while (in_len-- && ( encoded[in_] != '=')) {
char_array_4[i++] = encoded[in_]; in_++;
if (i == 4) {
for (i = 0; i < 4; i++)
char_array_4[i] = base64_chars.find(char_array_4[i]);
char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
for (i = 0; (i < 3); i++)
ret.push_back(char_array_3[i]);
i = 0;
}
}

if (i) {
for (j = i; j < 4; j++)
char_array_4[j] = 0;

for (j = 0; j < 4; j++)
char_array_4[j] = base64_chars.find(char_array_4[j]);

char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];

for (j = 0; (j < i - 1); j++) ret.push_back(char_array_3[j]);
}

return ret;
}

void SplitString(const std::string &str, const std::set <char> &chars, std::vector <std::string> &ret) {
ret.clear();
std::string now = "";
for (int i = 0; i < str.size(); i++) {
if (chars.find(str[i]) == chars.end()) {
now += str[i];
} else {
if (now != "") {
ret.push_back(now);
now = "";
}
}
}
if (now != "") {
ret.push_back(now);
}
}
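Together, these two helpers parse GLM4's tokenizer.model, whose lines have the form `<base64-encoded token> <id>`, exactly as the ChatGLM4Tokenizer branch below does. A quick sketch of that round trip, assuming the two functions above are in scope; the sample lines are made up (aGVsbG8= and d29ybGQ= are base64 for "hello" and "world"):

#include <cstdlib>
#include <iostream>
#include <set>
#include <string>
#include <vector>

// Defined above in this file.
std::string Base64Decode(const std::string &encoded);
void SplitString(const std::string &str, const std::set <char> &chars, std::vector <std::string> &ret);

int main() {
    std::string fileContents = "aGVsbG8= 42\nd29ybGQ= 43\n";  // fake tokenizer.model
    std::vector <std::string> lines, line;
    SplitString(fileContents, {'\r', '\n'}, lines);
    for (size_t i = 0; i < lines.size(); i++) {
        SplitString(lines[i], {' '}, line);
        // Prints "hello -> 42" and "world -> 43".
        std::cout << Base64Decode(line[0]) << " -> " << atoi(line[1].c_str()) << "\n";
    }
}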

// Load from a Hugging Face folder; only safetensors-format models are supported
std::unique_ptr <basellm> CreateLLMModelFromHF(const std::string &modelPath,
DataType linearDataType, int groupCnt) {
@@ -314,12 +373,16 @@
}

// 1. Check for model.safetensors.index.json; if present, read it
+ std::set <std::string> stFiles;
std::string stIndexFile = path + "model.safetensors.index.json";
std::string error;
- auto stIndex = json11::Json::parse(ReadAllFile(stIndexFile), error)["weight_map"];
- std::set <std::string> stFiles;
- for (auto it : stIndex.object_items()) {
-     stFiles.insert(path + it.second.string_value());
+ if (access(stIndexFile.c_str(), R_OK) != 0) {
+     stFiles.insert(path + "model.safetensors");
+ } else {
+     auto stIndex = json11::Json::parse(ReadAllFile(stIndexFile), error)["weight_map"];
+     for (auto it : stIndex.object_items()) {
+         stFiles.insert(path + it.second.string_value());
+     }
}
SafeTensors safeTensors(stFiles);

@@ -355,6 +418,39 @@
tokenizer["decoder"]["type"].string_value() == "ByteLevel") {
model->weight.tokenizer.byteAsChar = true;
}
} else if (tokenizerClass == "ChatGLM4Tokenizer") {
// The tokenizer dedicated to GLM4
model->bot_role = " ";
std::vector <std::string> lines, line;
SplitString(ReadAllFile(path + "tokenizer.model"), {'\r', '\n'}, lines);
for (int i = 0; i < lines.size(); i++) {
SplitString(lines[i], {' '}, line);
model->weight.AddTokenizerWord(Base64Decode(line[0]), atoi(line[1].c_str()), 1.0f);
}
std::map<std::string, int> spTokens;
for (auto &it : tokenizerConfig["added_tokens_decoder"].object_items()) {
spTokens[it.second["content"].string_value()] = atoi(it.first.c_str());
}
model->weight.tokenizer.SetSpecialTokens(spTokens);
((ChatGLMModel*)model)->gmask_token_id = model->weight.tokenizer.GetTokenId("[gMASK]");
((ChatGLMModel*)model)->bos_token_id = model->weight.tokenizer.GetTokenId("<sop>");
((ChatGLMModel*)model)->tokenizerClass = tokenizerClass;

// Set eos_token_id
if (config["eos_token_id"].is_array()) {
for (auto &it : config["eos_token_id"].array_items()) {
model->eos_token_ids.insert(it.int_value());
}
} else {
model->eos_token_id = config["eos_token_id"].int_value();
}

// ChatGLM builds prompts by concatenating tokens, so the separator tokens' IDs must be pinned explicitly
model->pre_prompt = "";
model->user_role = ("<FLM_FIX_TOKEN_" + std::to_string(model->weight.tokenizer.GetTokenId("<|user|>")) + ">\n");
model->bot_role = ("<FLM_FIX_TOKEN_" + std::to_string(model->weight.tokenizer.GetTokenId("<|assistant|>")) + ">");
model->history_sep = "";
model->weight.tokenizer.type = Tokenizer::TokenizerType::QWEN;
} else {
ErrorInFastLLM("Unsupport tokenizer_class: " + tokenizerClass);
}
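The net effect of the role strings above is that a GLM4 chat prompt is assembled from fixed-token escapes which the QWEN-type tokenizer later maps straight to their ids. A sketch of one turn; 151336 and 151337 are the usual GLM4 ids for <|user|> and <|assistant|>, but fastllm looks the real values up in the tokenizer at load time, as shown above:

#include <iostream>
#include <string>

int main() {
    // Mirrors the assignments above, with hypothetical concrete token ids.
    std::string pre_prompt = "";
    std::string user_role = "<FLM_FIX_TOKEN_151336>\n";  // <|user|>
    std::string bot_role  = "<FLM_FIX_TOKEN_151337>";    // <|assistant|>
    std::string history_sep = "";

    // One user turn followed by the assistant tag that cues generation.
    std::string prompt = pre_prompt + user_role + "Hello" + bot_role;
    std::cout << prompt << "\n";
}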
28 changes: 20 additions & 8 deletions src/models/chatglm.cpp
@@ -69,6 +69,12 @@ namespace fastllm {
this->UpdateRotaryPosEmb(1.0f);
weight.embeddingNames.insert("transformer.word_embeddings.weight");
weight.embeddingNames.insert("transformer.embedding.word_embeddings.weight");

weight.linearNames = {
"*.query_key_value.weight", "*.dense.weight",
"*.mlp.dense_h_to_4h.weight", "*.mlp.dense_4h_to_h.weight",
"lm_head.weight", "transformer.output_layer.weight"
};
}

void ChatGLMModel::InitParams() {
@@ -77,13 +83,19 @@
if (this->weight.dicts.find("gmask_token_id") != this->weight.dicts.end()) {
this->gmask_token_id = atoi(this->weight.dicts["gmask_token_id"].c_str());
}
- } else if (GetVersion() == 2) {
+ } else if (GetVersion() == 2 && this->tokenizerClass != "ChatGLM4Tokenizer") {
this->gmask_token_id = 64790;
this->bos_token_id = 64792;
}
if (this->weight.dicts.find("rope_ratio") != this->weight.dicts.end()) {
if (this->weight.dicts.find("rope_ratio") != this->weight.dicts.end()) {
UpdateRotaryPosEmb(atof(this->weight.dicts["rope_ratio"].c_str()));
}
if (this->weight.dicts.find("layernorm_epsilon") != this->weight.dicts.end()) {
this->layernorm_epsilon = atof(this->weight.dicts["layernorm_epsilon"].c_str());
}
if (this->weight.dicts.find("seq_length") != this->weight.dicts.end()) {
max_positions = atoi(this->weight.dicts["seq_length"].c_str());
}
}

int ChatGLMModel::Forward(const fastllm::Data &inputIds, const fastllm::Data &attentionMask,
@@ -143,7 +155,7 @@
} else if (version == 2) {
std::string inputRMSWeightName =
"transformer.encoder.layers." + std::to_string(i) + ".input_layernorm.weight";
- RMSNorm(hiddenStates, weight[inputRMSWeightName], 1e-5, attenInput);
+ RMSNorm(hiddenStates, weight[inputRMSWeightName], layernorm_epsilon, attenInput);
}
std::string qkvWeightName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.weight";
std::string qkvBiasName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.bias";
@@ -291,7 +303,7 @@
std::string postRMSWeightName =
"transformer.encoder.layers." + std::to_string(i) + ".post_attention_layernorm.weight";
Mul(hiddenStates, 1.0, temp);
- RMSNorm(hiddenStates, weight[postRMSWeightName], 1e-5, mlpInput);
+ RMSNorm(hiddenStates, weight[postRMSWeightName], this->layernorm_epsilon, mlpInput);
// 1.4 MLP
std::string fcInKeyName = "transformer.encoder.layers." + std::to_string(i) + ".mlp.dense_h_to_4h";
std::string fcOutKeyName = "transformer.encoder.layers." + std::to_string(i) + ".mlp.dense_4h_to_h";
@@ -325,7 +337,7 @@
weight["transformer.final_layernorm.bias"], -1, hiddenStates);
Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
} else {
RMSNorm(hiddenStates, weight["transformer.encoder.final_layernorm.weight"], 1e-5, hiddenStates);
RMSNorm(hiddenStates, weight["transformer.encoder.final_layernorm.weight"], this->layernorm_epsilon, hiddenStates);
Linear(hiddenStates, weight["transformer.output_layer.weight"], Data(), logits);
}

@@ -434,7 +446,7 @@
} else if (version == 2) {
std::string inputRMSWeightName =
"transformer.encoder.layers." + std::to_string(i) + ".input_layernorm.weight";
- RMSNorm(hiddenStates, weight[inputRMSWeightName], 1e-5, attenInput);
+ RMSNorm(hiddenStates, weight[inputRMSWeightName], this->layernorm_epsilon, attenInput);
}

std::string qkvWeightName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.weight";
@@ -690,7 +702,7 @@
"transformer.encoder.layers." + std::to_string(i) + ".post_attention_layernorm.weight";
Data temp;
Mul(hiddenStates, 1.0, temp);
- RMSNorm(hiddenStates, weight[postRMSWeightName], 1e-5, mlpInput);
+ RMSNorm(hiddenStates, weight[postRMSWeightName], this->layernorm_epsilon, mlpInput);
// 1.4 MLP
std::string fcInKeyName = "transformer.encoder.layers." + std::to_string(i) + ".mlp.dense_h_to_4h";
std::string fcOutKeyName = "transformer.encoder.layers." + std::to_string(i) + ".mlp.dense_4h_to_h";
@@ -711,7 +723,7 @@
weight["transformer.final_layernorm.bias"], -1, hiddenStates);
Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
} else {
RMSNorm(hiddenStates, weight["transformer.encoder.final_layernorm.weight"], 1e-5, hiddenStates);
RMSNorm(hiddenStates, weight["transformer.encoder.final_layernorm.weight"], this->layernorm_epsilon, hiddenStates);
Linear(hiddenStates, weight["transformer.output_layer.weight"], Data(), logits);
}
ToDataType(logits, DataType::FLOAT32);
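Every hard-coded 1e-5 call site above now reads this->layernorm_epsilon, which InitParams fills from the model config; GLM4 configs ship a much smaller epsilon than the old hard-coded 1e-5, which is why this became configurable. As a reference for what the parameter does, a scalar sketch of the RMSNorm math only; fastllm's RMSNorm runs on its own Data type:

#include <cmath>
#include <vector>

// Reference RMSNorm: y[i] = x[i] * w[i] / sqrt(mean(x^2) + eps).
// eps keeps the division stable when the activations are near zero.
std::vector<float> RMSNormRef(const std::vector<float> &x,
                              const std::vector<float> &w, float eps) {
    double sumsq = 0.0;
    for (float v : x) sumsq += (double)v * v;
    float scale = 1.0f / std::sqrt((float)(sumsq / x.size()) + eps);
    std::vector<float> y(x.size());
    for (size_t i = 0; i < x.size(); i++) y[i] = x[i] * scale * w[i];
    return y;
}

int main() {
    std::vector<float> x = {1.0f, 2.0f, 3.0f}, w = {1.0f, 1.0f, 1.0f};
    RMSNormRef(x, w, 1e-5f);  // eps would come from layernorm_epsilon
}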
