From 4aafb944e5dfcc6698d805aacc5ba510cc4ce069 Mon Sep 17 00:00:00 2001
From: 黄宇扬 <huang.yuyang@think-force.com>
Date: Thu, 19 Sep 2024 14:56:16 +0800
Subject: [PATCH] Support LoRA
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 include/model.h               |  3 +-
 src/model.cpp                 | 95 +++++++++++++++++++++++++++++++++--
 tools/fastllm_pytools/llm.py  |  7 +--
 tools/fastllm_pytools/util.py |  3 +-
 tools/src/pytools.cpp         |  4 +-
 5 files changed, 101 insertions(+), 11 deletions(-)

diff --git a/include/model.h b/include/model.h
index cd6557f8..45463e02 100644
--- a/include/model.h
+++ b/include/model.h
@@ -19,7 +19,8 @@ namespace fastllm {
                                                    DataType linearDataType,
                                                    int groupCnt = -1,
                                                    bool skipTokenizer = false,
-                                                   const std::string &modelConfig = "");
+                                                   const std::string &modelConfig = "",
+                                                   const std::string &loraPath = "");
 
     std::unique_ptr<basellm> CreateLLMTokenizerFromHF(const std::string &modelPath);
 }
diff --git a/src/model.cpp b/src/model.cpp
index 0a281d21..dda6d7c5 100644
--- a/src/model.cpp
+++ b/src/model.cpp
@@ -270,9 +270,13 @@ namespace fastllm {
         DataType srcType;
         if (this->dtype == "BF16") {
             srcType = DataType::BFLOAT16;
-        } else if (this->dtype == "F16")
-        {
+        } else if (this->dtype == "F16") {
             srcType = DataType::FLOAT16;
+        } else if (this->dtype == "F32") {
+            srcType = DataType::FLOAT32;
+            if (dstType != DataType::FLOAT32) {
+                ErrorInFastLLM("SafeTensorItem.CreateBuffer: unsupport src dtype " + this->dtype + "\n");
+            }
         } else {
             ErrorInFastLLM("SafeTensorItem.CreateBuffer: unsupport src dtype " + this->dtype + "\n");
         }
@@ -532,9 +536,35 @@ namespace fastllm {
 
     // 从hf文件夹读取,仅支持safetensor格式的模型
     std::unique_ptr <basellm> CreateLLMModelFromHF(const std::string &modelPath,
-                                                   DataType linearDataType, int groupCnt, bool skipTokenizer, const std::string &modelConfig) {
-        bool isJsonModel = (modelConfig.size() > 0);
+                                                   DataType linearDataType, int groupCnt, bool skipTokenizer, const std::string &modelConfig,
+                                                   const std::string &loraPath) {
+        std::map <std::string, std::pair <std::string, std::string> > loraDicts;
+        SafeTensors *loraTensors = nullptr;
+        float loraScaling;
+        if (loraPath != "") {
+            std::string path = loraPath;
+            if (path.back() != '/' && path.back() != '\\') {
+                path += "/";
+            }
+            loraTensors = new SafeTensors({path + "adapter_model.safetensors"});
+            for (auto &it : loraTensors->GetSortedItemNames()) {
+                if (it.size() >= 31 &&
+                    it.substr(0, 17) == "base_model.model." &&
+                    (it.substr(it.size() - 14) == ".lora_A.weight" || it.substr(it.size() - 14) == ".lora_B.weight")) {
+                    std::string originalName = it.substr(17, it.size() - 31) + ".weight";
+                    if (it.substr(it.size() - 14) == ".lora_A.weight") {
+                        loraDicts[originalName].first = it;
+                    } else {
+                        loraDicts[originalName].second = it;
+                    }
+                }
+            }
+            std::string loraConfigError;
+            auto loraConfig = json11::Json::parse(ReadAllFile(path + "adapter_config.json"), loraConfigError);
+            loraScaling = loraConfig["lora_alpha"].number_value() / loraConfig["r"].number_value();
+        }
+
+        bool isJsonModel = (modelConfig.size() > 0);
         std::string path = modelPath;
         if (path.back() != '/' || path.back() != '\\') {
             path += "/";
@@ -681,6 +711,61 @@ namespace fastllm {
                         oriDataType = DataType::FLOAT16;
                     }
                     tensor.CreateBuffer(oriDataType);
+
+                    if (loraDicts.find(weightName) != loraDicts.end()) {
+                        std::string loraA = loraDicts[weightName].first;
+                        std::string loraB = loraDicts[weightName].second;
+
+                        int inDim = loraTensors->itmeDict[loraA].intShape[1];
+                        int outDim = loraTensors->itmeDict[loraB].intShape[0];
+                        int lora = loraTensors->itmeDict[loraA].intShape[0];
+
+                        AssertInFastLLM(loraTensors->itmeDict[loraA].dtype == "F32" &&
+                                        loraTensors->itmeDict[loraB].dtype == "F32",
+                                        "Lora error: lora's dtype should be F32.");
+                        loraTensors->itmeDict[loraA].CreateBuffer(DataType::FLOAT32);
+                        loraTensors->itmeDict[loraB].CreateBuffer(DataType::FLOAT32);
+                        float *weightA = (float*)loraTensors->itmeDict[loraA].buffer;
+                        float *weightB = (float*)loraTensors->itmeDict[loraB].buffer;
+
+                        std::vector <float> loraFactor;
+                        loraFactor.resize(inDim * outDim, 0.0f);
+                        for (int i = 0; i < outDim; i++) {
+                            for (int j = 0; j < lora; j++) {
+                                for (int k = 0; k < inDim; k++) {
+                                    loraFactor[i * inDim + k] += weightB[i * lora + j] * weightA[j * inDim + k];
+                                }
+                            }
+                        }
+                        for (int i = 0; i < loraFactor.size(); i++) {
+                            loraFactor[i] *= loraScaling;
+                        }
+
+                        loraTensors->itmeDict[loraA].ClearBuffer();
+                        loraTensors->itmeDict[loraB].ClearBuffer();
+
+                        if (oriDataType == DataType::BFLOAT16) {
+                            uint16_t *fp16Weight = (uint16_t*)tensor.buffer;
+                            for (int i = 0; i < loraFactor.size(); i++) {
+                                uint32_t now = fp16Weight[i] << 16;
+                                float newV = ((float*)&now)[0] + loraFactor[i];
+                                fp16Weight[i] = ((uint32_t*)&newV)[0] >> 16;
+                            }
+                        } else if (oriDataType == DataType::FLOAT16) {
+                            uint16_t *fp16Weight = (uint16_t*)tensor.buffer;
+                            for (int i = 0; i < loraFactor.size(); i++) {
+                                fp16Weight[i] = float_to_half(half_to_float(fp16Weight[i]) + loraFactor[i]);
+                            }
+                        } else if (oriDataType == DataType::FLOAT32) {
+                            float *fp32Weight = (float*)tensor.buffer;
+                            for (int i = 0; i < loraFactor.size(); i++) {
+                                fp32Weight[i] = fp32Weight[i] + loraFactor[i];
+                            }
+                        } else {
+                            ErrorInFastLLM("Lora error, dtype should be float32, float16 or bfloat16.");
+                        }
+                    }
+
                     if (it.second == DATA_AUTO_CONV) {
                         tensor.Transpose(oriDataType);
                     }
@@ -704,6 +789,8 @@ namespace fastllm {
         printf("\n");
         fflush(stdout);
 
+        delete loraTensors;
+
         model->WarmUp();
         return std::unique_ptr<fastllm::basellm> (model);
     }
diff --git a/tools/fastllm_pytools/llm.py b/tools/fastllm_pytools/llm.py
index fc5f4e45..9f8bd5b9 100644
--- a/tools/fastllm_pytools/llm.py
+++ b/tools/fastllm_pytools/llm.py
@@ -19,7 +19,7 @@
 fastllm_lib.create_llm_model.argtypes = [ctypes.c_char_p]
 fastllm_lib.create_llm_model.restype = ctypes.c_int
 
-fastllm_lib.create_llm_model_fromhf.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.c_bool]
+fastllm_lib.create_llm_model_fromhf.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.c_bool, ctypes.c_char_p]
 fastllm_lib.create_llm_model_fromhf.restype = ctypes.c_int
 
 fastllm_lib.create_llm_model_fromhf_with_config.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.c_bool, ctypes.c_char_p]
@@ -387,7 +387,8 @@ def __init__ (self, path : str,
                   eos_token: List[str] = [],
                   tokenizer_type = "auto",
                   model_json: str = "",
-                  graph: type = None):
+                  graph: type = None,
+                  lora: str = ""):
         if (graph != None):
             current_graph = graph()
         if (os.path.isdir(path) and os.path.isfile(os.path.join(path, "config.json"))):
@@ -435,7 +436,7 @@ def __init__ (self, path : str,
                                                                              ctypes.c_bool(self.hf_tokenizer != None), model_json.encode());
             else:
                 self.model = fastllm_lib.create_llm_model_fromhf(path.encode(), fastllm_data_type_dict[dtype], int4g_groupcnt,
-                                                                 ctypes.c_bool(self.hf_tokenizer != None));
+                                                                 ctypes.c_bool(self.hf_tokenizer != None), lora.encode());
         else:
             print("path error: ", path);
             exit(0)
diff --git a/tools/fastllm_pytools/util.py b/tools/fastllm_pytools/util.py
index 716f071e..b08eb043 100644
--- a/tools/fastllm_pytools/util.py
+++ b/tools/fastllm_pytools/util.py
@@ -13,6 +13,7 @@ def make_normal_parser(des: str) -> argparse.ArgumentParser:
     parser.add_argument('--max_batch', type = int, default = -1, help = '每次最多同时推理的询问数量')
     parser.add_argument('--device', type = str, help = '使用的设备')
     parser.add_argument('--custom', type = str, default = "", help = '指定描述自定义模型的python文件')
+    parser.add_argument('--lora', type = str, default = "", help = '指定lora路径')
     return parser
 
 def make_normal_llm_model(args):
@@ -40,7 +41,7 @@ def make_normal_llm_model(args):
             spec.loader.exec_module(custom_module)
             if (hasattr(custom_module, "__model__")):
                 graph = getattr(custom_module, "__model__")
-    model = llm.model(args.path, dtype = args.dtype, graph = graph, tokenizer_type = "auto")
+    model = llm.model(args.path, dtype = args.dtype, graph = graph, tokenizer_type = "auto", lora = args.lora)
     model.set_atype(args.atype)
     if (args.max_batch > 0):
         model.set_max_batch(args.max_batch)
diff --git a/tools/src/pytools.cpp b/tools/src/pytools.cpp
index 4e4ba012..4f1a8671 100644
--- a/tools/src/pytools.cpp
+++ b/tools/src/pytools.cpp
@@ -102,10 +102,10 @@ extern "C" {
         return id;
     }
 
-    DLL_EXPORT int create_llm_model_fromhf(char *path, int dataType, int groupCnt, bool skipTokenizer) {
+    DLL_EXPORT int create_llm_model_fromhf(char *path, int dataType, int groupCnt, bool skipTokenizer, char *lora) {
         models.locker.lock();
         int id = models.models.size();
-        models.models[id] = fastllm::CreateLLMModelFromHF(path, (fastllm::DataType)dataType, groupCnt, skipTokenizer);
+        models.models[id] = fastllm::CreateLLMModelFromHF(path, (fastllm::DataType)dataType, groupCnt, skipTokenizer, "", lora);
         models.locker.unlock();
         return id;
     }
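
With this patch applied, a LoRA adapter directory (containing adapter_model.safetensors and adapter_config.json, with lora_A/lora_B stored as F32) is folded into the base weights while the model loads, before they are converted to the runtime dtype. A minimal usage sketch from Python; the paths are placeholders, and it assumes the bindings are imported as fastllm_pytools.llm as in the files touched above:

    from fastllm_pytools import llm

    # Merge the adapter into the base weights at load time; passing
    # lora = "" (the default) skips the merge entirely.
    model = llm.model("/path/to/base-model", dtype = "float16",
                      lora = "/path/to/lora-adapter")

Command-line tools built on make_normal_parser gain the same knob as --lora /path/to/lora-adapter.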
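For reference, the triple loop in the model.cpp hunk computes the standard LoRA merge W' = W + (lora_alpha / r) * B @ A for every weight that has a matching lora_A/lora_B pair. An equivalent NumPy sketch with toy shapes, to make the layout explicit (illustrative only, not part of the patch):

    import numpy as np

    out_dim, in_dim, r = 8, 16, 4
    lora_alpha = 8.0
    W = np.random.randn(out_dim, in_dim).astype(np.float32)  # base weight
    A = np.random.randn(r, in_dim).astype(np.float32)        # lora_A: [r, in_dim]
    B = np.random.randn(out_dim, r).astype(np.float32)       # lora_B: [out_dim, r]

    scaling = lora_alpha / r        # loraScaling in the patch
    merged = W + scaling * (B @ A)  # what the i/j/k loop accumulates

    # The patch's loop, spelled out: identical result.
    loop = W.copy()
    for i in range(out_dim):
        for j in range(r):
            for k in range(in_dim):
                loop[i, k] += scaling * B[i, j] * A[j, k]
    assert np.allclose(merged, loop, atol = 1e-5)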
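The BFLOAT16 branch relies on bfloat16 being the top 16 bits of an IEEE float32: each weight is widened with a 16-bit left shift, the LoRA delta is added in float32, and the result is truncated back by dropping the low 16 bits (truncation, not round-to-nearest). The same round trip in NumPy (illustrative only):

    import numpy as np

    bf16 = np.array([0x3F80, 0xBF00], dtype = np.uint16)  # 1.0, -0.5 in bf16
    delta = np.array([0.25, 0.25], dtype = np.float32)    # LoRA deltas

    widened = (bf16.astype(np.uint32) << 16).view(np.float32)  # bf16 -> f32
    merged = widened + delta                                   # add in f32
    out = (merged.view(np.uint32) >> 16).astype(np.uint16)     # truncate back

    print([hex(v) for v in out])  # ['0x3fa0', '0xbe80'] -> 1.25, -0.25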