
Commit

Add support for TeleAI's telechat2-3B and 7B
黄宇扬 committed Dec 13, 2024
1 parent 60871ce commit 24d73d8
Showing 3 changed files with 62 additions and 15 deletions.
40 changes: 34 additions & 6 deletions docs/models.md
@@ -4,15 +4,14 @@

Fastllm currently supports the following ways of loading a model.

* **Convert after loading (two-line acceleration mode)** (convert on-the-fly)
Load the original model as a HuggingFace model, then convert and accelerate it with the `from_hf()` method. This approach uses a lot of memory and is slow, and is no longer recommended.
* **Direct loading** (load from HuggingFace .safetensors)
Directly load a .safetensors-format model published on HuggingFace (models in other formats can be exported to safetensors with the transformers library; see [Exporting a safetensors model](#导出safetensors模型)).

* **Offline conversion** (convert offline)
Convert the original model to the .flm format; some [models](#flm模型库) have already been converted.

* **Direct loading** (load from HuggingFace .safetensors)
Directly load a model published on HuggingFace; only .safetensors-format models are supported.

* **Convert after loading (not recommended)**
Load the original model as a HuggingFace model, then convert and accelerate it with the `from_hf()` method. This approach uses a lot of memory and is slow, and is no longer recommended (a minimal sketch follows this list).
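
For reference, a minimal sketch of this deprecated convert-on-the-fly path is shown below. It is an assumption-laden illustration: it presumes the `fastllm_pytools` bindings are built and importable, and the model path and final generation call are examples only.

``` python
# Hedged sketch of the convert-on-the-fly ("two-line acceleration") path.
# Assumptions: fastllm_pytools is installed and exposes llm.from_hf();
# the model id below and the response() call are illustrative only.
from transformers import AutoModelForCausalLM, AutoTokenizer
from fastllm_pytools import llm

path = "TeleAI/TeleChat2-7B"  # hypothetical model id/path
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code = True)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype = "auto", trust_remote_code = True)

# The "two lines": convert the in-memory HuggingFace model into an accelerated
# fastllm model. The full original model is loaded first, which is why this
# path uses more memory than reading .safetensors directly.
model = llm.from_hf(model, tokenizer, dtype = "float16")
print(model.response("你好"))  # illustrative generation call
```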

## Supported Models (Model List)

@@ -121,6 +120,11 @@

| Model | Convert after loading | Offline conversion | Direct loading |
|-----------------: |------------|------------|------------|
| microsoft/Phi-3-mini-4k-instruct | | | ✔ |
| google/gemma-2-9b | | | ✔ |
| google/gemma-2-27b | | | ✔ |
| TeleAI/TeleChat2-3B | | | ✔ |
| TeleAI/TeleChat2-7B | | | ✔ |
| fnlp/moss-moon-003-sft | [✔]() | [✔](#moss模型导出) | |
| fnlp/moss-moon-003-sft-plugin | [✔]() | [✔](#moss模型导出) | |
| | | | |
@@ -132,9 +136,33 @@
| openbmb/MiniCPM-2B-dpo-fp16 | [✔](#其它模型) | [✔](#minicpm模型导出) | |
| openbmb/MiniCPM3-4B | [✔](#其它模型) | [✔](#minicpm模型导出) | |
| | | | |
| microsoft/Phi-3-mini-4k-instruct | | | ✔ |


### Exporting a safetensors model

A model can be exported to the .safetensors format with the transformers library, as shown below:

``` python
# Save this script as trans.py, then run:
# python trans.py --input <path of the original model> --output <path to export the .safetensors model to> (can be the same as --input)
from transformers import AutoModelForCausalLM, AutoTokenizer

import argparse

def parse_args():
    parser = argparse.ArgumentParser(description = "trans")
    parser.add_argument("--input", type = str, required = True)
    parser.add_argument("--output", type = str, required = True)
    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()
    tokenizer = AutoTokenizer.from_pretrained(args.input, trust_remote_code = True)
    # Load on CPU with the checkpoint's native dtype, then re-save as sharded .safetensors.
    model = AutoModelForCausalLM.from_pretrained(args.input, device_map = "cpu", torch_dtype = "auto", trust_remote_code = True).eval()

    model.save_pretrained(args.output, max_shard_size = "2048MB", safe_serialization = True)
    tokenizer.save_pretrained(args.output, max_shard_size = "2048MB", safe_serialization = True)
```
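
After exporting, the resulting directory is a regular .safetensors checkpoint, so it can be read by Fastllm via the direct-loading path described above.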

### Convert after loading (two-line acceleration mode) (convert on-the-fly)

#### ChatGLM series
6 changes: 5 additions & 1 deletion src/model.cpp
@@ -93,14 +93,18 @@ namespace fastllm {
}
if (this->weight.dicts.find("num_hidden_layers") != this->weight.dicts.end()) {
block_cnt = atoi(this->weight.dicts["num_hidden_layers"].c_str());
}else if (this->weight.dicts.find("num_layers") != this->weight.dicts.end()) {
} else if (this->weight.dicts.find("num_layers") != this->weight.dicts.end()) {
block_cnt = atoi(this->weight.dicts["num_layers"].c_str());
} else if (this->weight.dicts.find("n_layer") != this->weight.dicts.end()) {
block_cnt = atoi(this->weight.dicts["n_layer"].c_str());
}
if (this->weight.dicts.find("hidden_size") != this->weight.dicts.end()) {
embed_dim = atoi(this->weight.dicts["hidden_size"].c_str());
}
if (this->weight.dicts.find("num_attention_heads") != this->weight.dicts.end()) {
num_attention_heads = atoi(this->weight.dicts["num_attention_heads"].c_str());
} else if (this->weight.dicts.find("n_head") != this->weight.dicts.end()) {
num_attention_heads = atoi(this->weight.dicts["n_head"].c_str());
}
if (this->weight.dicts.find("pre_prompt") != this->weight.dicts.end()) {
pre_prompt = this->weight.dicts["pre_prompt"];
31 changes: 23 additions & 8 deletions src/models/graph/telechat.cpp
@@ -4,14 +4,21 @@ namespace fastllm {
class TeleChatGraphModelConfig : GraphLLMModelConfig {
public:
enum TeleChatModelType {
TeleChat7B, TeleChat52B
TeleChat7B, TeleChat52B,
TeleChat2
};
TeleChatModelType teleChatModelType = TeleChatModelType::TeleChat7B;

void InitParams(GraphLLMModel *model) {
if (model->weight.dicts.find("n_positions") != model->weight.dicts.end()) {
teleChatModelType = TeleChat52B;
}

std::string error;
auto config = json11::Json::parse(model->weight.dicts["architectures"], error);
if (config.array_items()[0].string_value() == "Telechat2ForCausalLM") {
teleChatModelType = TeleChat2;
}

if (teleChatModelType == TeleChat52B) {
model->block_cnt = atoi(model->weight.dicts["n_layer"].c_str());
@@ -33,6 +40,10 @@ namespace fastllm {
model->user_role = "<_user>";
model->bot_role = "<_bot>";
model->history_sep = "";

if (teleChatModelType == TeleChat2) {
model->rope_base = 10000;
}
}
}

@@ -67,6 +78,11 @@ namespace fastllm {
};
std::string embeddingName = "transformer.word_embeddings.weight";
std::string logitsName = "transformer.lm_head.weight";

if (teleChatModelType == TeleChat2) {
logitsName = "lm_head.weight";
}

ret[embeddingName].push_back(std::make_pair(embeddingName, DataType::DATA_AUTO_EMBEDDING));
for (int i = 0; i < model->block_cnt; i++) {
std::string pre = "transformer.h." + std::to_string(i);
@@ -84,11 +100,6 @@
} else {
ret[logitsName][0].second = DataType::DATA_AUTO_LINEAR;
}
if (ret.find(logitsName) == ret.end()) {
ret[embeddingName].push_back(std::make_pair(logitsName, DataType::DATA_AUTO_LINEAR));
} else {
ret[logitsName][0].second = DataType::DATA_AUTO_LINEAR;
}
}
return ret;
}
@@ -133,6 +144,11 @@
OptimizeComputeGraph(graph, model->weight);
graph.Update();
} else {
std::string logitsName = "transformer.lm_head.weight";
if (teleChatModelType == TeleChat2) {
logitsName = "lm_head.weight";
}

auto &graph = *(model->GetGraph());
std::map <std::string, ComputeGraphNode> wNodes;
for (auto &it : model->weight.weight) {
@@ -169,8 +185,7 @@

graph.SplitLastTokenStates(hiddenStates, seqLens, lastTokensStates);
graph.RMSNorm(lastTokensStates, wNodes["transformer.ln_f.weight"], model->rms_norm_eps, lastTokensStates);
graph.Linear(lastTokensStates, wNodes["transformer.lm_head.weight"], wNodes["transformer.lm_head.bias"], logits);

graph.Linear(lastTokensStates, wNodes[logitsName], wNodes[""], logits);
OptimizeComputeGraph(graph, model->weight);
graph.Update();
}
