diff --git a/docs/models.md b/docs/models.md index cb60ef7..4052dcb 100644 --- a/docs/models.md +++ b/docs/models.md @@ -4,15 +4,14 @@ 目前Fastllm加载模型有以下几种方式。 -* **加载后转换(两行加速模式)** (convert on-the-fly) - 将原始模型加载为HuggingFace模型,再通过`from_hf()`方法,转换并加速,这种方法内存占用大且速度慢,目前不再推荐。 +* **直接读取** (load from Huggingface .safetensors) + 直接读取HuggingFace上发布的.safetensors格式的模型(其他格式的模型可以使用transformer库导出成safetensors格式,参见[导出safetensors模型](#导出safetensors模型)。 * **离线转换** (convert offline) 将原始模型转换为.flm格式的模型,一些[模型](#flm模型库)已经转换好。 -* **直接读取** (load from Huggingface .safetensors) - 直接读取HuggingFace上发布的模型,仅支持.safetensors格式的模型。 - +* **加载后转换(不推荐)** + 将原始模型加载为HuggingFace模型,再通过`from_hf()`方法,转换并加速,这种方法内存占用大且速度慢,目前不再推荐。 ## 支持模型一览 Model List @@ -121,6 +120,11 @@ | 模型 | 加载后转换 | 离线转换 | 直接读取 | |-----------------: |------------|------------|------------| +| microsoft/Phi-3-mini-4k-instruct | | | ✔ | +| google/gemma-2-9b | | | ✔ | +| google/gemma-2-27b | | | ✔ | +| TeleAI/TeleChat2-3B | | | ✔ | +| TeleAI/TeleChat2-7B | | | ✔ | | fnlp/moss-moon-003-sft | [✔]() | [✔](#moss模型导出) | | | fnlp/moss-moon-003-sft-plugin | [✔]() | [✔](#moss模型导出) | | | | | | | @@ -132,9 +136,33 @@ | openbmb/MiniCPM-2B-dpo-fp16 | [✔](#其它模型) | [✔](#minicpm模型导出) | | | openbmb/MiniCPM3-4B | [✔](#其它模型) | [✔](#minicpm模型导出) | | | | | | | -| microsoft/Phi-3-mini-4k-instruct | | | ✔ | +### 导出safetensors模型 + +通过transformers库可以将模型导出成.safetensors格式,代码如下: + +``` python +# 保存这段代码为trans.py, 然后执行 +# python trans.py --input 原模型地址 --output 导出.safetensors模型的地址(可以和input相同) +from transformers import AutoModelForCausalLM, AutoTokenizer + +import argparse +def parse_args(): + parser = argparse.ArgumentParser(description = "trans") + parser.add_argument("--input", type = str, required = True) + parser.add_argument("--output", type = str, required = True) + return parser.parse_args() + +if __name__ == "__main__": + args = parse_args() + tokenizer = AutoTokenizer.from_pretrained(args.input, trust_remote_code = True) + model = AutoModelForCausalLM.from_pretrained(args.input, device_map = "cpu",torch_dtype = "auto", trust_remote_code = True).eval() + + model.save_pretrained(args.output, max_shard_size = "2048MB", safe_serialization = True) + tokenizer.save_pretrained(args.output, max_shard_size = "2048MB", safe_serialization = True) +``` + ### 加载后转换(两行加速模式)(convert on-the-fly) #### ChatGLM系列 diff --git a/src/model.cpp b/src/model.cpp index 29a0344..5a1e69b 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -93,14 +93,18 @@ namespace fastllm { } if (this->weight.dicts.find("num_hidden_layers") != this->weight.dicts.end()) { block_cnt = atoi(this->weight.dicts["num_hidden_layers"].c_str()); - }else if (this->weight.dicts.find("num_layers") != this->weight.dicts.end()) { + } else if (this->weight.dicts.find("num_layers") != this->weight.dicts.end()) { block_cnt = atoi(this->weight.dicts["num_layers"].c_str()); + } else if (this->weight.dicts.find("n_layer") != this->weight.dicts.end()) { + block_cnt = atoi(this->weight.dicts["n_layer"].c_str()); } if (this->weight.dicts.find("hidden_size") != this->weight.dicts.end()) { embed_dim = atoi(this->weight.dicts["hidden_size"].c_str()); } if (this->weight.dicts.find("num_attention_heads") != this->weight.dicts.end()) { num_attention_heads = atoi(this->weight.dicts["num_attention_heads"].c_str()); + } else if (this->weight.dicts.find("n_head") != this->weight.dicts.end()) { + num_attention_heads = atoi(this->weight.dicts["n_head"].c_str()); } if (this->weight.dicts.find("pre_prompt") != this->weight.dicts.end()) { pre_prompt = this->weight.dicts["pre_prompt"]; diff --git a/src/models/graph/telechat.cpp b/src/models/graph/telechat.cpp index b9ca3ee..5851cda 100644 --- a/src/models/graph/telechat.cpp +++ b/src/models/graph/telechat.cpp @@ -4,7 +4,8 @@ namespace fastllm { class TeleChatGraphModelConfig : GraphLLMModelConfig { public: enum TeleChatModelType { - TeleChat7B, TeleChat52B + TeleChat7B, TeleChat52B, + TeleChat2 }; TeleChatModelType teleChatModelType = TeleChatModelType::TeleChat7B; @@ -12,6 +13,12 @@ namespace fastllm { if (model->weight.dicts.find("n_positions") != model->weight.dicts.end()) { teleChatModelType = TeleChat52B; } + + std::string error; + auto config = json11::Json::parse(model->weight.dicts["architectures"], error); + if (config.array_items()[0].string_value() == "Telechat2ForCausalLM") { + teleChatModelType = TeleChat2; + } if (teleChatModelType == TeleChat52B) { model->block_cnt = atoi(model->weight.dicts["n_layer"].c_str()); @@ -33,6 +40,10 @@ namespace fastllm { model->user_role = "<_user>"; model->bot_role = "<_bot>"; model->history_sep = ""; + + if (teleChatModelType == TeleChat2) { + model->rope_base = 10000; + } } } @@ -67,6 +78,11 @@ namespace fastllm { }; std::string embeddingName = "transformer.word_embeddings.weight"; std::string logitsName = "transformer.lm_head.weight"; + + if (teleChatModelType == TeleChat2) { + logitsName = "lm_head.weight"; + } + ret[embeddingName].push_back(std::make_pair(embeddingName, DataType::DATA_AUTO_EMBEDDING)); for (int i = 0; i < model->block_cnt; i++) { std::string pre = "transformer.h." + std::to_string(i); @@ -84,11 +100,6 @@ namespace fastllm { } else { ret[logitsName][0].second = DataType::DATA_AUTO_LINEAR; } - if (ret.find(logitsName) == ret.end()) { - ret[embeddingName].push_back(std::make_pair(logitsName, DataType::DATA_AUTO_LINEAR)); - } else { - ret[logitsName][0].second = DataType::DATA_AUTO_LINEAR; - } } return ret; } @@ -133,6 +144,11 @@ namespace fastllm { OptimizeComputeGraph(graph, model->weight); graph.Update(); } else { + std::string logitsName = "transformer.lm_head.weight"; + if (teleChatModelType == TeleChat2) { + logitsName = "lm_head.weight"; + } + auto &graph = *(model->GetGraph()); std::map wNodes; for (auto &it : model->weight.weight) { @@ -169,8 +185,7 @@ namespace fastllm { graph.SplitLastTokenStates(hiddenStates, seqLens, lastTokensStates); graph.RMSNorm(lastTokensStates, wNodes["transformer.ln_f.weight"], model->rms_norm_eps, lastTokensStates); - graph.Linear(lastTokensStates, wNodes["transformer.lm_head.weight"], wNodes["transformer.lm_head.bias"], logits); - + graph.Linear(lastTokensStates, wNodes[logitsName], wNodes[""], logits); OptimizeComputeGraph(graph, model->weight); graph.Update(); }