diff --git a/CMakeLists.txt b/CMakeLists.txt
index 75cfaafe..fa5254b7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -7,10 +7,14 @@ option(USE_CUDA "use cuda" OFF)
 option(PY_API "python api" OFF)
 option(USE_MMAP "use mmap" OFF)
 
+option(USE_SENTENCEPIECE "use sentencepiece" OFF)
+
 message(STATUS "USE_CUDA: ${USE_CUDA}")
 message(STATUS "PYTHON_API: ${PY_API}")
 
+message(STATUS "USE_SENTENCEPIECE: ${USE_SENTENCEPIECE}")
+
 set(CMAKE_BUILD_TYPE "Release")
 
 if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
@@ -25,7 +29,7 @@ endif()
 message(STATUS "CMAKE_CXX_FLAGS" ${CMAKE_CXX_FLAGS})
 set(FASTLLM_CXX_SOURCES src/fastllm.cpp src/device.cpp src/model.cpp src/executor.cpp
         src/devices/cpu/cpudevice.cpp src/devices/cpu/cpudevicebatch.cpp
-        src/models/chatglm.cpp src/models/moss.cpp src/models/llama.cpp src/models/qwen.cpp src/models/basellm.cpp)
+        src/models/chatglm.cpp src/models/moss.cpp src/models/llama.cpp src/models/qwen.cpp src/models/basellm.cpp src/models/glm.cpp)
 
 include_directories(include)
 include_directories(include/utils)
@@ -35,6 +39,12 @@ if (USE_MMAP)
     add_compile_definitions(USE_MMAP)
 endif()
 
+if (USE_SENTENCEPIECE)
+    set(CMAKE_CXX_STANDARD 17)
+    add_compile_definitions(USE_SENTENCEPIECE)
+    set(FASTLLM_LINKED_LIBS ${FASTLLM_LINKED_LIBS} sentencepiece)
+endif()
+
 if (USE_CUDA)
     enable_language(CUDA)
     add_compile_definitions(USE_CUDA)
diff --git a/include/fastllm.h b/include/fastllm.h
index c5a8284f..a1ea663f 100644
--- a/include/fastllm.h
+++ b/include/fastllm.h
@@ -19,6 +19,10 @@
 #include 
 #include "devices/cpu/cputhreadpool.h"
 
+#ifdef USE_SENTENCEPIECE
+#include <sentencepiece_processor.h>
+#endif
+
 namespace fastllm {
     void SetDeviceMap(const std::map <std::string, int> &deviceMap);
     std::map <std::string, int> GetDeviceMap();
@@ -308,7 +312,8 @@
         enum TokenizerType {
             BPE = 0,
             NORMAL = 1,
-            QWEN = 2
+            QWEN = 2,
+            GLM = 3
         };
 
         struct TrieNode {
@@ -359,6 +364,9 @@
         std::unordered_map <int, std::string> tokenToStringDict;
         std::unordered_map <int, float> tokenToScoreDict;
         std::unordered_map <std::string, int> stringToTokenDict;
+#ifdef USE_SENTENCEPIECE
+        std::unique_ptr <sentencepiece::SentencePieceProcessor> spProcessor;
+#endif
 
         Tokenizer ();
diff --git a/include/models/glm.h b/include/models/glm.h
new file mode 100644
index 00000000..b71502a5
--- /dev/null
+++ b/include/models/glm.h
@@ -0,0 +1,62 @@
+//
+// Created by huangyuyang on 5/11/23.
+//
+
+#ifndef FASTLLM_GLM_H
+#define FASTLLM_GLM_H
+
+#include "basellm.h"
+#include "cmath"
+
+#include 
+
+namespace fastllm {
+    class GLMModel: public basellm {
+    public:
+        GLMModel (); // constructor
+
+        // inference
+        virtual int Forward(
+                const Data &inputIds,
+                const Data &attentionMask,
+                const Data &positionIds,
+                std::vector <std::pair <Data, Data> > &pastKeyValues,
+                const GenerationConfig &generationConfig = GenerationConfig(),
+                const LastTokensManager &lastTokens = LastTokensManager(),
+                std::vector <float> *logits = nullptr);
+
+        std::vector <int> ForwardBatch(
+                int batch,
+                const Data &inputIds,
+                const Data &attentionMask,
+                const Data &positionIds,
+                std::vector <std::pair <Data, Data> > &pastKeyValues,
+                const GenerationConfig &generationConfig = GenerationConfig(),
+                const LastTokensManager &lastTokens = LastTokensManager(),
+                std::vector <std::vector <float>*> *retLogits = nullptr);
+
+        // build the LLM inference inputs from the input tokens
+        virtual void FillLLMInputs(std::vector <std::vector <float> > &inputTokens,
+                                   const std::map <std::string, int> &params,
+                                   Data &inputIds, Data &attentionMask, Data &positionIds);
+
+        virtual void InitParams();
+        virtual void WarmUp(); // warm up
+
+        virtual std::string MakeInput(const std::string &history, int round, const std::string &input); // build the prompt from the history and the current input
+
+        virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output); // update the history with the current reply
+
+    private:
+
+        float scale_attn_1;
+
+        static constexpr int eot_token_id = 50000;//<|endoftext|>
+        static constexpr int cls_token_id = 50002;//[CLS]
+        static constexpr int mask_token_id = 50003;//[MASK]
+        static constexpr int smask_token_id = 50008;//[sMASK]
+        static constexpr int gmask_token_id = 50009;//[gMASK]
+    };
+}
+
+#endif //FASTLLM_GLM_H
diff --git a/src/fastllm.cpp b/src/fastllm.cpp
index 7bc4a1d3..4f0db830 100644
--- a/src/fastllm.cpp
+++ b/src/fastllm.cpp
@@ -933,6 +933,132 @@
                 }
             }
             return Data (DataType::FLOAT32, {1, (int)v.size()}, v);
+        } else if (this->type == TokenizerType::GLM) {
+            const std::map <std::string, int> specialTokens = {{"[MASK]", 50003}, {"[sMASK]", 50008}, {"[gMASK]", 50009}};
+            std::string blank = "";
+            blank += 226, blank += 150, blank += 129;
+            std::string s = blank;
+            for (int i = 0; i < ori.size(); i++) {
+                if (ori[i] == ' ') {
+                    if (i != 0 && ori[i - 1] != ' ') {
+                        s += blank;
+                    }
+                } else {
+                    s += ori[i];
+                }
+            }
+            std::vector <float> v;
+            int findPos=0;
+            while(findPos<s.length()){
+                int nextSpecialToken=-1;
+                int nextSpecialTokenPos=-1;
+                int nextSpecialTokenLen=0;
+                for(auto &p:specialTokens){
+                    int ind=s.find(p.first,findPos);
+                    if(ind>=0&&(nextSpecialTokenPos<0||ind<nextSpecialTokenPos)){
+                        nextSpecialToken=p.second;
+                        nextSpecialTokenPos=ind;
+                        nextSpecialTokenLen=p.first.length();
+                    }
+                }
+                std::string subStr;
+                if(nextSpecialTokenPos<0){
+                    subStr=s.substr(findPos);
+                    findPos=s.length();
+                }else{
+                    subStr=s.substr(findPos,nextSpecialTokenPos-findPos);
+                    findPos=nextSpecialTokenPos+nextSpecialTokenLen;
+                }
+                if(subStr.size()>0){
+#ifdef USE_SENTENCEPIECE
+                    if(spProcessor!=nullptr){
+                        std::vector <int> ids;
+                        spProcessor->Encode(subStr,&ids);
+                        for(int id:ids){
+                            v.push_back(id);
+                        }
+                    }else{
+#endif
+                    std::vector <Symbol> symbols;
+                    for (int i = 0; i < subStr.size(); i++) {
+                        int tokenId = -999999, pos = i - 1;
+                        TrieNode *now = this->root;
+                        for (int j = i; j < subStr.size(); j++) {
+                            if (now->next.find(subStr[j]) != now->next.end()) {
+                                now = now->next[subStr[j]];
+                                if (now->tokenId != -999999) {
+                                    tokenId = now->tokenId;
+                                    pos = j;
+                                    break;
+                                }
+                            } else {
+                                break;
+                            }
+                        }
+                        if (pos >= i) {
+                            symbols.push_back(Symbol(now, (char *) subStr.data(), i, pos - i + 1, (int) symbols.size() - 1,
+                                                     (int) symbols.size() + 1, -999999));
+                            i = pos;
+                        } else {
+                            symbols.push_back(Symbol(nullptr, (char *) subStr.data(), i, 0, (int) symbols.size() - 1,
+                                                     (int) symbols.size() + 1, -999999));
+                        }
+                    }
+                    symbols.back().next = -1;
+
+                    std::priority_queue <SymbolPairs> workQueue;
+                    for (int i = 1; i < symbols.size(); i++) {
+                        TryMergePairs(symbols, i - 1, i, workQueue);
+                    }
+
+                    while (!workQueue.empty()) {
+                        auto top = workQueue.top();
+                        workQueue.pop();
+                        if (symbols[top.l].len == 0 || symbols[top.r].len == 0 ||
+                            symbols[top.l].len + symbols[top.r].len != top.size) {
+                            continue;
+                        }
+
+                        for (int i = symbols[top.r].pos; i < symbols[top.r].pos + symbols[top.r].len; i++) {
+                            symbols[top.l].node = symbols[top.l].node->next[symbols[top.r].s[i]];
+                        }
+                        symbols[top.l].len += symbols[top.r].len;
+                        symbols[top.r].len = 0;
+                        symbols[top.l].next = symbols[top.r].next;
+                        if (symbols[top.r].next >= 0) {
+                            symbols[symbols[top.r].next].prev = top.l;
+                        }
+
+                        TryMergePairs(symbols, symbols[top.l].prev, top.l, workQueue);
+                        TryMergePairs(symbols, top.l, symbols[top.l].next, workQueue);
+                    }
+                    for (int i = 0; i < symbols.size(); i++) {
+                        if (symbols[i].len > 0) {
+                            v.push_back(symbols[i].node->tokenId);
+                        } else if (symbols[i].node == nullptr) {
+                            if (symbols[i].fixId != -999999) {
+                                v.push_back(symbols[i].fixId);
+                            } else {
+                                // unrecognized character
+                                uint8_t c = (uint8_t) (symbols[i].s[symbols[i].pos]);
+                                std::string now = "<0x00>";
+                                now[3] = (c / 16 > 9 ? ('A' + c / 16 - 10) : ('0' + c / 16));
+                                now[4] = (c % 16 > 9 ? ('A' + c % 16 - 10) : ('0' + c % 16));
+                                if (stringToTokenDict.find(now) != stringToTokenDict.end()) {
+                                    v.push_back(stringToTokenDict[now]);
+                                }
+                            }
+                        }
+                    }
+#ifdef USE_SENTENCEPIECE
+                    }
+#endif
+                }
+                if(nextSpecialTokenPos>=0){
+                    v.push_back(nextSpecialToken);
+                }
+            }
+            return Data (DataType::FLOAT32, {1, (int)v.size()}, v);
         } else if (this->type == TokenizerType::QWEN) {
             std::map <std::string, int> specialTokens = {{"<|im_start|>", 151644}, {"<|im_end|>", 151645}, {"<|endoftext|>", 151643}};
@@ -1960,4 +2086,4 @@
     std::map <std::string, int> GetDeviceMap() {
         return defaultDeviceMap;
     }
-}
\ No newline at end of file
+}
diff --git a/src/model.cpp b/src/model.cpp
index 61e990bc..401e5ab9 100644
--- a/src/model.cpp
+++ b/src/model.cpp
@@ -7,6 +7,7 @@
 #include "moss.h"
 #include "llama.h"
 #include "qwen.h"
+#include "glm.h"
 
 namespace fastllm {
     void basellm::LoadFromFile(const std::string &fileName) {
@@ -16,8 +17,12 @@
 
     void basellm::InitParams() {
         if (this->weight.dicts.find("bos_token_id") != this->weight.dicts.end()) {
-            this->bos_token_id = atoi(this->weight.dicts["bos_token_id"].c_str());
-            this->eos_token_id = atoi(this->weight.dicts["eos_token_id"].c_str());
+            if(this->weight.dicts["bos_token_id"]!="None"){
+                this->bos_token_id = atoi(this->weight.dicts["bos_token_id"].c_str());
+            }
+            if(this->weight.dicts["eos_token_id"]!="None"){
+                this->eos_token_id = atoi(this->weight.dicts["eos_token_id"].c_str());
+            }
         }
         if (this->weight.dicts.find("im_start_id") != this->weight.dicts.end()) {
             this->bos_token_id = atoi(this->weight.dicts["im_start_id"].c_str());
@@ -25,6 +30,8 @@
         }
         if (this->weight.dicts.find("num_hidden_layers") != this->weight.dicts.end()) {
             block_cnt = atoi(this->weight.dicts["num_hidden_layers"].c_str());
+        }else if (this->weight.dicts.find("num_layers") != this->weight.dicts.end()) {
+            block_cnt = atoi(this->weight.dicts["num_layers"].c_str());
         }
         if (this->weight.dicts.find("hidden_size") != this->weight.dicts.end()) {
             embed_dim = atoi(this->weight.dicts["hidden_size"].c_str());
@@ -77,6 +84,8 @@
         } else if (modelType == "qwen") {
             model = (basellm *) (new QWenModel());
             model->weight.tokenizer.type = Tokenizer::TokenizerType::QWEN;
+        } else if (modelType == "glm") {
+            model = (basellm*)(new GLMModel());
         } else {
             ErrorInFastLLM("Unkown model type: " + modelType);
         }
@@ -95,4 +104,4 @@
         basellm *model = CreateModelWithType(modelType);
         return std::unique_ptr <basellm> (model);
     }
-}
\ No newline at end of file
+}
diff --git a/src/models/basellm.cpp b/src/models/basellm.cpp
index 362a471e..18af45b5 100644
--- a/src/models/basellm.cpp
+++ b/src/models/basellm.cpp
@@ -777,4 +777,4 @@ printf("tot = %d\n", tot);
     void basellm::DisableAdapter() {
         adapterName = "";
     }
-}
\ No newline at end of file
+}
diff --git a/src/models/glm.cpp b/src/models/glm.cpp
new file mode 100644
index 00000000..a5481698
--- /dev/null
+++ b/src/models/glm.cpp
@@ -0,0 +1,336 @@
+//
+// Created by huangyuyang on 5/11/23.
+//
+
+#include "utils.h"
+
+#include "glm.h"
+
+#include 
+
+#include 
+
+#include 
+
+#include 
+
+#include 
+
+#include 
+
+#include 
+
+#ifdef USE_CUDA
+#include "fastllm-cuda.cuh"
+#endif
+
+namespace fastllm {
+
+    GLMModel::GLMModel() {
+        this->model_type = "glm";
+
+        this->bos_token_id = 50006;//<|startofpiece|>
+        this->eos_token_id = 50007;//<|endofpiece|>
+
+        weight.embeddingNames.insert("word_embeddings.weight");
+        weight.embeddingNames.insert("transformer.position_embeddings.weight");
+        weight.embeddingNames.insert("transformer.block_position_embeddings.weight");
+        weight.tokenizer.type=Tokenizer::GLM;
+        weight.tokenizer.Insert("[MASK]",mask_token_id);
+        weight.tokenizer.Insert("[sMASK]",smask_token_id);
+        weight.tokenizer.Insert("[gMASK]",gmask_token_id);
+    }
+
+    int GLMModel::Forward(const fastllm::Data &inputIds, const fastllm::Data &attentionMask,
+                          const fastllm::Data &positionIds, std::vector <std::pair <Data, Data> > &pastKeyValues,
+                          const GenerationConfig &generationConfig, const LastTokensManager &lastTokens,
+                          std::vector <float> *logits) {
+        std::vector <std::vector <float>*> batchLogits;
+        batchLogits.push_back(logits);
+        return ForwardBatch(1, inputIds, attentionMask, positionIds, pastKeyValues, generationConfig, lastTokens, &batchLogits)[0];
+    }
+
+    std::vector <int> GLMModel::ForwardBatch(
+            int batch,
+            const Data &inputIds,
+            const Data &attentionMask,
+            const Data &positionIds,
+            std::vector <std::pair <Data, Data> > &pastKeyValues,
+            const GenerationConfig &generationConfig,
+            const LastTokensManager &lastTokens,
+            std::vector <std::vector <float>*> *retLogits) {
+        int maxLen = inputIds.dims[1];
+        Data attentionMask4D;
+        Data attnScoreAdds;
+        Data inputEmbeddings;
+        Data position_ids_1D;
+        Data block_position_ids_1D;
+        Data positionEmbeddings;
+        Data blockPositionEmbeddings;
+        Data attenInput;
+        Data qkv, q, k, v,q0;
+        Data attnScores;
+        Data attnProbs;
+        Data attnOutput;
+        Data contextLayer;
+        Data contextLayerPermute;
+        Data mlpInput;
+        Data mlpOutput;
+        Data middle, middle2;
+        Data toSave;
+        Data mem2;
+        std::vector <int> lastRet;
+        // GLMBlock
+        std::string weightPre, weightMiddle;
+        weightPre = "transformer.layers.";
+        weightMiddle = ".attention";
+
+        {
+            Data attentionMask4D_1x;
+            attentionMask4D_1x.CopyFrom(attentionMask);
+            attentionMask4D_1x.Reshape({1,1,attentionMask.dims[0],attentionMask.dims[1]});
+            std::vector <Data*> masks(num_attention_heads);
+            for(int i=0;i<num_attention_heads;i++){
+                masks[i]=&attentionMask4D_1x;
+            }
+            CatBatch(masks,1,attentionMask4D);
+            std::vector <float> one(attentionMask4D.Count(0),-65504.0);
+            attnScoreAdds.CopyFrom(Data(DataType::FLOAT32,attentionMask4D.dims,one));
+            AddTo(attnScoreAdds,attentionMask4D,65504.0);
+        }
+        Embedding(inputIds, this->weight["word_embeddings.weight"], inputEmbeddings);
+        Data &hiddenStates = inputEmbeddings;
+        Split(positionIds,0,0,1,position_ids_1D);
+        Split(positionIds,0,1,2,block_position_ids_1D);
+        Embedding(position_ids_1D, this->weight["transformer.position_embeddings.weight"], positionEmbeddings);
+        AddTo(hiddenStates,positionEmbeddings);
+        Embedding(block_position_ids_1D, this->weight["transformer.block_position_embeddings.weight"], blockPositionEmbeddings);
+        AddTo(hiddenStates,blockPositionEmbeddings);
+        int memory_length=(pastKeyValues[0].first.dims.size()==0?0:pastKeyValues[0].first.dims.at(1));
+        int query_length=hiddenStates.dims.at(1);
+        int new_memory_length=memory_length+query_length;
+        if(new_memory_length<=query_length){
+            Split(hiddenStates,1,hiddenStates.dims.at(1)-new_memory_length,hiddenStates.dims.at(1),toSave);
+        }else{
+            Split(hiddenStates,1,0,hiddenStates.dims.at(1),toSave);//Copy
+        }
+        for (int i = 0; i < block_cnt; i++) {
+            Data &mem=pastKeyValues[i].first;
+            bool hasMem=(mem.dims.size()!=0);
+            ApplyDeviceMap(this->deviceMap, i + 1, block_cnt);
+            std::string inputLNWeightName = "transformer.layers." + std::to_string(i) + ".input_layernorm.weight";
+            std::string inputLNBiasName = "transformer.layers." + std::to_string(i) + ".input_layernorm.bias";
+            LayerNorm(hiddenStates, weight[inputLNWeightName], weight[inputLNBiasName], -1, attenInput);
+            std::string qkvWeightName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.weight";
+            std::string qkvBiasName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.bias";
+            if(!hasMem){
+                Linear(attenInput, weight[qkvWeightName], weight[qkvBiasName], qkv);
+                int per = qkv.dims.back() / 3;
+                Split(qkv, -1, 0, per, q);
+                Split(qkv, -1, per, per * 2, k);
+                Split(qkv, -1, per * 2, per * 3, v);
+            }else{
+                LayerNorm(mem, weight[inputLNWeightName], weight[inputLNBiasName], -1, mem2);
+                CatDirect(mem2,attenInput,1);
+                Linear(mem2, weight[qkvWeightName], weight[qkvBiasName], qkv);
+                int per = qkv.dims.back() / 3;
+                Split(qkv, -1, 0, per, q0);
+                Split(qkv, -1, per, per * 2, k);
+                Split(qkv, -1, per * 2, per * 3, v);
+                int tLen=q0.dims.at(1);
+                Split(q0,1,tLen-attenInput.dims.at(1),tLen,q);
+            }
+            q.Reshape({q.dims[0], q.dims[1], num_attention_heads, -1});
+            PermuteSelf(q,{0,2,1,3});
+            k.Reshape({k.dims[0], k.dims[1], num_attention_heads, -1});
+            //PermuteSelf(k,{0,2,1,3});// (1)
+            v.Reshape({v.dims[0], v.dims[1], num_attention_heads, -1});
+            PermuteSelf(v,{0,2,1,3});
+            //PermuteSelf(k,{0,1,2,3});// (2)
+            PermuteSelf(k,{0,2,3,1});// Merged (1) + (2)
+            MatMul(q,k,attnScores,scale_attn_1);
+            MulTo(attnScores,attentionMask4D);
+            AddTo(attnScores,attnScoreAdds);
+            Softmax(attnScores, attnProbs, -1);
+            MatMul(attnProbs,v,contextLayer);
+            PermuteSelf(contextLayer,{0,2,1,3});
+            contextLayer.Reshape({contextLayer.dims[0],contextLayer.dims[1],embed_dim});
+            std::string denseWeightName = weightPre + std::to_string(i) + weightMiddle + ".dense.weight";
+            std::string denseBiasName = weightPre + std::to_string(i) + weightMiddle + ".dense.bias";
+            Linear(contextLayer, weight[denseWeightName], weight[denseBiasName], attnOutput);
+            AddTo(hiddenStates,attnOutput);
+            std::string postLNWeightName =
+                    "transformer.layers." + std::to_string(i) + ".post_attention_layernorm.weight";
+            std::string postLNBiasName =
+                    "transformer.layers." + std::to_string(i) + ".post_attention_layernorm.bias";
+            LayerNorm(hiddenStates, weight[postLNWeightName], weight[postLNBiasName], -1, mlpInput);
+            std::string fcInKeyName = "transformer.layers." + std::to_string(i) + ".mlp.dense_h_to_4h";
+            std::string fcOutKeyName = "transformer.layers." + std::to_string(i) + ".mlp.dense_4h_to_h";
+            Linear(mlpInput, weight[fcInKeyName + ".weight"], weight[fcInKeyName + ".bias"], middle);
+            GeluNew(middle, middle);
+            Linear(middle, weight[fcOutKeyName + ".weight"], weight[fcOutKeyName + ".bias"], mlpOutput);
+            AddTo(hiddenStates,mlpOutput);
+            if(new_memory_length<=query_length){
+                Split(toSave,1,0,toSave.dims.at(1),mem);//Copy
+                Split(hiddenStates,1,hiddenStates.dims.at(1)-new_memory_length,hiddenStates.dims.at(1),toSave);
+            }else{
+                Split(mem,1,mem.dims.at(1)-new_memory_length+query_length,mem.dims.at(1),mem2);
+                Cat(mem2,toSave,1,mem);
+                Split(hiddenStates,1,0,hiddenStates.dims.at(1),toSave);//Copy
+            }
+        }
+        Data logits, topk;
+        LayerNorm(hiddenStates, weight["transformer.final_layernorm.weight"],
+                  weight["transformer.final_layernorm.bias"], -1, hiddenStates);
+        Linear(hiddenStates, weight["word_embeddings.weight"], Data(), logits);
+        if (generationConfig.output_logits && retLogits != nullptr) {
+            int size = logits.dims.back();
+            logits.ToDevice(DataDevice::CPU);
+            for (int b = 0; b < batch; b++) {
+                int base = (maxLen - 1) * batch + b;
+                (*retLogits)[b]->resize(size);
+                memcpy((float*)(*retLogits)[b]->data(), ((float*)logits.cpuData) + base * size, size * logits.unitSize);
+            }
+        }
+        if (generationConfig.IsSimpleGreedy()) {
+            TopK(logits, topk, 1);
+            topk.ToDevice(DataDevice::CPU);
+            for (int b = 0; b < batch; b++) {
+                int base = (maxLen - 1) * batch + b;
+                lastRet.push_back((int) (((float *) topk.cpuData)[base * 2] + 1e-3));
+            }
+        } else if (!lastTokens.units.empty()) {
+            for (int b = 0; b < batch; b++) {
+                int base = (maxLen - 1) * batch + b;
+                lastRet.push_back(LLMSampling(logits, base, generationConfig, lastTokens.units[b]));
+            }
+        }
+        return lastRet;
+    }
+
+    void GLMModel::FillLLMInputs(std::vector <std::vector <float> > &inputTokens,
+                                 const std::map <std::string, int> &params,
+                                 Data &inputIds, Data &attentionMask, Data &positionIds) {
+        inputIds.ToDevice(DataDevice::CPU);
+        attentionMask.ToDevice(DataDevice::CPU);
+        positionIds.ToDevice(DataDevice::CPU);
+
+        int index = params.find("index")->second;
+
+        if (index == 0) {
+            int mask_pos=-1;
+            for (auto &ids: inputTokens) {
+                bool hasMask=false;
+                for(unsigned int i=0;i<ids.size();i++){
+                    int id=(int)ids[i];
+                    if(id==mask_token_id||id==smask_token_id||id==gmask_token_id){
+                        hasMask=true;
+                        mask_pos=i+1;// +1 for the [CLS] token inserted below
+                        break;
+                    }
+                }
+                if(!hasMask){
+                    mask_pos=ids.size()+1;
+                    ids.push_back(gmask_token_id);
+                }
+                ids.insert(ids.begin(),cls_token_id);
+                ids.push_back(eot_token_id);
+                ids.push_back(bos_token_id);// <|startofpiece|> opens the span to generate
+            }
+            int seqLen=inputTokens[0].size();
+            std::vector <float> vpids=std::vector <float> (seqLen*2,0);//position_ids
+            for(int i=0;i<seqLen;i++){
+                vpids[i]=i;
+            }
+            vpids[seqLen-1]=mask_pos;// the generation slot points at the mask position
+            vpids[seqLen*2-1]=1;// its block position starts at 1
+            std::vector <float> vmask=std::vector <float> (seqLen*seqLen,1);//attention_mask
+            for(int i=0;i+1<seqLen;i++){
+                vmask[i*seqLen+seqLen-1]=0;// context tokens do not attend to the generation slot
+            }
+            inputIds.CopyFrom(Data(DataType::FLOAT32, {1, seqLen}, inputTokens[0]));
+            attentionMask.CopyFrom(Data(DataType::FLOAT32, {seqLen, seqLen}, vmask));
+            positionIds.CopyFrom(Data(DataType::FLOAT32, {2, seqLen}, vpids));
+        } else {
+            unsigned int tokenLen=inputTokens[0].size();
+            int totalLen=attentionMask.dims.back()+(int)tokenLen;
+            float *oldPositions=(float*)(positionIds.cpuData);
+            int posLen=positionIds.dims.at(1);
+            std::vector <float> newAttention(totalLen,1);
+            std::vector <float> newPosition(tokenLen*2);
+            for(unsigned int i=0;i<tokenLen;i++){
+                newPosition[i]=oldPositions[posLen-1];// generated tokens keep the mask position
+                newPosition[tokenLen+i]=oldPositions[posLen*2-1]+i+1;// block position keeps growing
+            }
+            inputIds.CopyFrom(Data(DataType::FLOAT32, {1, static_cast<int>(tokenLen)}, inputTokens[0]));
+            attentionMask.CopyFrom(Data(DataType::FLOAT32, {1, totalLen}, newAttention));
+            positionIds.CopyFrom(Data(DataType::FLOAT32, {2, static_cast<int>(tokenLen)}, newPosition));
+        }
+    }
+
+    void GLMModel::InitParams()
+    {
+        basellm::InitParams();
+        head_dim = embed_dim / num_attention_heads;
+        scale_attn_1 = 1.0f/sqrt(head_dim);
+#ifdef USE_SENTENCEPIECE
+        if (this->weight.dicts.find("tokenizer_serialized") != this->weight.dicts.end()) {
+            const std::string &hexString=this->weight.dicts["tokenizer_serialized"];
+            if(hexString.length()%2!=0){
+                std::cerr << "Invalid hex string\n";
+            }else{
+                std::string decoded;
+                for(unsigned int i=0;i<hexString.length();i+=2){
+                    std::string byteString=hexString.substr(i,2);
+                    decoded.push_back((char)strtol(byteString.c_str(),nullptr,16));
+                }
+                weight.tokenizer.spProcessor=std::make_unique<sentencepiece::SentencePieceProcessor>();
+                weight.tokenizer.spProcessor->LoadFromSerializedProto(decoded);
+                printf("GetPieceSize=%d\n",weight.tokenizer.spProcessor->GetPieceSize());
+            }
+        }
+#endif
+    }
+
+    void GLMModel::WarmUp() {
+//        printf("Warmup...\n");
+//        Data inputIds = Data(DataType::FLOAT32, {1, 1}, {(float)bos_token_id});
+//        Data attentionMask = Data(DataType::FLOAT32, {1, 1}, {0});
+//        Data positionIds = Data(DataType::FLOAT32, {2, 1}, {0, 0});
+
+//        std::vector <std::pair <Data, Data> > pastKeyValues;
+//        for (int i = 0; i < block_cnt; i++) {
+//            pastKeyValues.push_back(std::make_pair(Data(DataType::FLOAT32),
+//                                                   Data(DataType::FLOAT32)));
+//        }
+//        Forward(inputIds, attentionMask, positionIds, pastKeyValues);
+//        printf("finish.\n");
+    }
+
+    std::string GLMModel::MakeInput(const std::string &history, int round, const std::string &input) {
+        (void)history;
+        (void)round;
+        return input;
+    }
+
+    std::string GLMModel::MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output) {
+        (void)history;
+        (void)round;
+        (void)input;
+        (void)output;
+        return std::string("");
+    }
+}
diff --git a/tools/scripts/glm_export.py b/tools/scripts/glm_export.py
new file mode 100644
index 00000000..284e3c0d
--- /dev/null
+++ b/tools/scripts/glm_export.py
@@ -0,0 +1,133 @@
+import sys
+import struct
+import numpy as np
+import torch
+import binascii
+from transformers import AutoTokenizer, AutoModel
+from fastllm_pytools import torch2flm
+
+def glmtofile(exportPath,
+              model,
+              tokenizer = None,
+              dtype = "float16"):
+    if (dtype not in torch2flm.fastllm_data_type_dict):
+        print("dtype should in ", list(torch2flm.fastllm_data_type_dict.keys()))
+        exit(0)
+
+    dict = model.state_dict()
+    fo = open(exportPath, "wb")
+
+    # 0. version id
+    fo.write(struct.pack('i', 2))
+
+    # 0.1 model info
+    modelInfo = model.config.__dict__
+    if model.generation_config is not None:
+        modelInfo.update(model.generation_config.__dict__)
+    if ("model_type" not in modelInfo):
+        print("unknown model_type.")
+        exit(0)
+
+    modelInfo["tokenizer_use_score"] = "1" # the vocab is stored with scores
+    modelInfo["tokenizer_serialized"] = binascii.hexlify(tokenizer.sp_model.serialized_model_proto()).decode("latin-1") # serialized sentencepiece tokenizer
+
+    if hasattr(model, "peft_config"):
+        adapter_size = len(model.peft_config)
+        modelInfo["peft_size"] = adapter_size
+
+    fo.write(struct.pack('i', len(modelInfo)))
+    for it in modelInfo.keys():
+        torch2flm.writeKeyValue(fo, str(it), str(modelInfo[it]))
+
+    if hasattr(model, "peft_config"):
+        for adapter_name in model.peft_config.keys():
+            adapter_dict = model.peft_config[adapter_name].__dict__
+            torch2flm.writeString(fo, adapter_name)
+            fo.write(struct.pack('i', len(adapter_dict)))
+            for it in adapter_dict.keys():
+                torch2flm.writeKeyValue(fo, str(it), str(adapter_dict[it]))
+
+    # 1. vocab
+    if (tokenizer):
+        if (hasattr(tokenizer, "tokenizer")):
+            tokenizer = tokenizer.tokenizer
+        if (hasattr(tokenizer, "sp_model")):
+            piece_size = tokenizer.sp_model.piece_size()
+            fo.write(struct.pack('i', piece_size))
+            for i in range(piece_size):
+                s = tokenizer.sp_model.id_to_piece(i).encode()
+                fo.write(struct.pack('i', len(s)))
+                for c in s:
+                    fo.write(struct.pack('i', c))
+                fo.write(struct.pack('i', i))
+                fo.write(struct.pack('f', float(tokenizer.sp_model.get_score(i))))
+        else:
+            vocab = tokenizer.get_vocab()
+            fo.write(struct.pack('i', len(vocab)))
+            for v in vocab.keys():
+                s = v.encode()
+                fo.write(struct.pack('i', len(s)))
+                for c in s:
+                    fo.write(struct.pack('i', c))
+                fo.write(struct.pack('i', vocab[v]))
+                fo.write(struct.pack('f', 1.0))
+    else:
+        fo.write(struct.pack('i', 0))
+
+    weight_type_dict = {}
+    module_dict = {}
+    for key, m in model.named_modules():
+        if (isinstance(m, torch.nn.Linear)):
+            weight_type_dict[key + ".weight"] = "linear"
+            module_dict[key + ".weight"] = m
+        if (isinstance(m, torch.nn.Embedding)):
+            weight_type_dict[key] = "embedding"
+
+    # 2. weight
+    fo.write(struct.pack('i', len(dict)))
+    tot = 0
+    for key in dict:
+        ori_data_type = 0
+        ori_np_data_type = np.float32
+        cur_weight_type = 0
+        if (key in weight_type_dict and weight_type_dict[key] in torch2flm.fastllm_weight_type_dict):
+            cur_weight_type = torch2flm.fastllm_weight_type_dict[weight_type_dict[key]]
+        to_data_type = 0
+        if (cur_weight_type == 1):
+            to_data_type = torch2flm.fastllm_data_type_dict[dtype]
+            if (to_data_type == 7):
+                ori_data_type = 7
+                ori_np_data_type = np.float16
+
+        cur = dict[key].numpy().astype(ori_np_data_type)
+
+        if hasattr(model, "peft_config"):
+            weight_name = key.replace('base_model.model.', '')
+            fo.write(struct.pack('i', len(weight_name)))
+            fo.write(weight_name.encode())
+        else:
+            fo.write(struct.pack('i', len(key)))
+            fo.write(key.encode())
+        fo.write(struct.pack('i', len(cur.shape)))
+        for i in cur.shape:
+            fo.write(struct.pack('i', i))
+        if (to_data_type == 3):
+            write_int8(fo, cur)
+        elif (to_data_type == 8):
+            write_int4(fo, cur)
+        else:
+            fo.write(struct.pack('i', to_data_type))
+            fo.write(cur.data)
+        tot += 1
+        print("output (", tot, "/", len(dict), end = " )\r")
+    print("\nfinish.")
+    fo.close()
+
+if __name__ == "__main__":
+    tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-large-chinese", trust_remote_code=True)
+    model = AutoModel.from_pretrained("THUDM/glm-large-chinese", trust_remote_code=True)
+    model = model.eval()
+
+    dtype = sys.argv[2] if len(sys.argv) >= 3 else "float32"
+    exportPath = sys.argv[1] if len(sys.argv) >= 2 else "glm-" + dtype + ".flm"
+    glmtofile(exportPath, model, tokenizer, dtype = dtype)
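Usage note (not part of the diff; a sketch under assumptions): the exporter embeds the serialized sentencepiece model into the .flm model info, and the runtime only consumes it when fastllm is configured with -DUSE_SENTENCEPIECE=ON (which also links libsentencepiece, per the CMake change above). A minimal export run, equivalent to the script's __main__ block, could look like the following; the output file name is only an example.

    # Hypothetical usage sketch: export THUDM/glm-large-chinese to fastllm's .flm format.
    # Equivalent to: python3 tools/scripts/glm_export.py glm-large-chinese-fp16.flm float16
    # (run from tools/scripts/ so that glm_export and fastllm_pytools are importable).
    from transformers import AutoTokenizer, AutoModel
    from glm_export import glmtofile

    tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-large-chinese", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/glm-large-chinese", trust_remote_code=True).eval()
    glmtofile("glm-large-chinese-fp16.flm", model, tokenizer, dtype = "float16")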