
Recursively read config when loading an HF model
黄宇扬 committed Jul 19, 2024
1 parent ec2921c commit 9c018ed
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions src/model.cpp
@@ -514,6 +514,17 @@ namespace fastllm {
return std::unique_ptr<fastllm::basellm> (model);
}

+    // Recursively add the contents of config into the model's dict
+    void AddDictRecursion(basellm *model, const std::string &pre, const json11::Json &config) {
+        for (auto &it : config.object_items()) {
+            if (it.second.is_object()) {
+                AddDictRecursion(model, pre + it.first + ".", it.second);
+            } else {
+                model->weight.AddDict(pre + it.first, it.second.is_string() ? it.second.string_value() : it.second.dump());
+            }
+        }
+    }
+
// Load from an HF folder; only models in safetensors format are supported
std::unique_ptr <basellm> CreateLLMModelFromHF(const std::string &modelPath,
DataType linearDataType, int groupCnt, bool skipTokenizer, const std::string &modelConfig) {
@@ -545,9 +556,7 @@ namespace fastllm {
if (isJsonModel) {
((GraphLLMModel*)model)->graphLLMModelConfig->Init(modelConfig);
}
-        for (auto &it : config.object_items()) {
-            model->weight.AddDict(it.first, it.second.is_string() ? it.second.string_value() : it.second.dump());
-        }
+        AddDictRecursion(model, "", config);
// Set eos_token_id
if (config["eos_token_id"].is_array()) {
for (auto &it : config["eos_token_id"].array_items()) {
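For illustration, here is a minimal standalone sketch of the flattening behavior this commit introduces, assuming json11 is available and using a plain std::map in place of the model->weight dictionary that AddDictRecursion fills through AddDict (the map, the demo function name, and the config fragment below are illustrative, not part of the commit). Nested objects in config.json end up as dotted keys, e.g. rope_scaling.type.

#include <iostream>
#include <map>
#include <string>
#include "json11.hpp"

// Same recursion as in the commit, but writing into a std::map stand-in
// instead of model->weight.AddDict.
static void AddDictRecursionDemo(std::map<std::string, std::string> &dict,
                                 const std::string &pre, const json11::Json &config) {
    for (auto &it : config.object_items()) {
        if (it.second.is_object()) {
            AddDictRecursionDemo(dict, pre + it.first + ".", it.second);
        } else {
            dict[pre + it.first] = it.second.is_string() ? it.second.string_value()
                                                         : it.second.dump();
        }
    }
}

int main() {
    std::string err;
    // A nested config.json fragment of the kind an HF model ships with.
    json11::Json config = json11::Json::parse(R"({
        "model_type": "llama",
        "eos_token_id": 2,
        "rope_scaling": {"type": "dynamic", "factor": 2.0}
    })", err);

    std::map<std::string, std::string> dict;
    AddDictRecursionDemo(dict, "", config);
    for (auto &kv : dict) {
        std::cout << kv.first << " = " << kv.second << "\n";
    }
    // Prints flattened keys: eos_token_id, model_type,
    // rope_scaling.factor, rope_scaling.type
    return 0;
}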
