From 4aafb944e5dfcc6698d805aacc5ba510cc4ce069 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?= <huang.yuyang@think-force.com>
Date: Thu, 19 Sep 2024 14:56:16 +0800
Subject: [PATCH] Support LoRA
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
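
Add an optional loraPath argument to CreateLLMModelFromHF. When it is set,
the loader reads adapter_model.safetensors and adapter_config.json from
that folder, computes scaling * lora_B * lora_A for every
"base_model.model.<name>.lora_{A,B}.weight" pair, and adds the result to
the matching base weight while the model loads. SafeTensorItem.CreateBuffer
also learns to pass through F32 tensors. The Python bindings gain a
matching `lora` parameter and the CLI parser gains a `--lora` flag.

A minimal usage sketch (paths are placeholders; assumes the bindings are
imported as fastllm_pytools, matching this repo's tools directory):

    from fastllm_pytools import llm
    model = llm.model("/path/to/base-model", dtype = "float16",
                      lora = "/path/to/lora-adapter")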

---
 include/model.h               |  3 +-
 src/model.cpp                 | 100 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--
 tools/fastllm_pytools/llm.py  |  7 +--
 tools/fastllm_pytools/util.py |  3 +-
 tools/src/pytools.cpp         |  4 +-
 5 files changed, 106 insertions(+), 11 deletions(-)

diff --git a/include/model.h b/include/model.h
index cd6557f8..45463e02 100644
--- a/include/model.h
+++ b/include/model.h
@@ -19,7 +19,8 @@ namespace fastllm {
                                                     DataType linearDataType, 
                                                     int groupCnt = -1,
                                                     bool skipTokenizer = false,
-                                                    const std::string &modelConfig = "");
+                                                    const std::string &modelConfig = "",
+                                                    const std::string &loraPath = "");
     
     std::unique_ptr<basellm> CreateLLMTokenizerFromHF(const std::string &modelPath);
 }
diff --git a/src/model.cpp b/src/model.cpp
index 0a281d21..dda6d7c5 100644
--- a/src/model.cpp
+++ b/src/model.cpp
@@ -270,9 +270,13 @@ namespace fastllm {
             DataType srcType;
             if (this->dtype == "BF16") {
                 srcType = DataType::BFLOAT16;
-            } else if (this->dtype == "F16") 
-            {
+            } else if (this->dtype == "F16") {
                 srcType = DataType::FLOAT16;
+            } else if (this->dtype == "F32") {
+                srcType = DataType::FLOAT32;
+                if (dstType != DataType::FLOAT32) {
+                    ErrorInFastLLM("SafeTensorItem.CreateBuffer: F32 src can only be kept as FLOAT32 dst\n");
+                }
             } else {
                 ErrorInFastLLM("SafeTensorItem.CreateBuffer: unsupport src dtype " + this->dtype + "\n");
             }
@@ -532,9 +536,37 @@
 
     // Load from an HF folder; only models in safetensors format are supported
     std::unique_ptr <basellm> CreateLLMModelFromHF(const std::string &modelPath, 
-                                                    DataType linearDataType, int groupCnt, bool skipTokenizer, const std::string &modelConfig) {
-        bool isJsonModel = (modelConfig.size() > 0);
+                                                    DataType linearDataType, int groupCnt, bool skipTokenizer, const std::string &modelConfig,
+                                                    const std::string &loraPath) {
+        std::map <std::string, std::pair <std::string, std::string> > loraDicts;
+        SafeTensors *loraTensors = nullptr;
+        float loraScaling = 1.0f;
+        if (loraPath != "") {
+            std::string path = loraPath;
+            if (path.back() != '/' && path.back() != '\\') {
+                path += "/";
+            }
+            loraTensors = new SafeTensors({path + "adapter_model.safetensors"});
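+            // Map each base weight name ("<name>.weight") to its (lora_A, lora_B) tensor names.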
+            for (auto &it : loraTensors->GetSortedItemNames()) {
+                if (it.size() >= 31 &&
+                    it.substr(0, 17) == "base_model.model." &&
+                    (it.substr(it.size() - 14) == ".lora_A.weight" || it.substr(it.size() - 14) == ".lora_B.weight")) {
+                    std::string originalName = it.substr(17, it.size() - 31) + ".weight";
+                    if (it.substr(it.size() - 14) == ".lora_A.weight") {
+                        loraDicts[originalName].first = it;
+                    } else {
+                        loraDicts[originalName].second = it;
+                    }
+                }
+            }
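+            // LoRA scales its update by lora_alpha / r, read from adapter_config.json.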
+            std::string loraConfigError;
+            auto loraConfig = json11::Json::parse(ReadAllFile(path + "adapter_config.json"), loraConfigError);
+            loraScaling = loraConfig["lora_alpha"].number_value() / loraConfig["r"].number_value();
+        }
 
+        bool isJsonModel = (modelConfig.size() > 0);
         std::string path = modelPath;
         if (path.back() != '/' || path.back() != '\\') {
             path += "/";
@@ -681,6 +713,64 @@
                                 oriDataType = DataType::FLOAT16;
                             }
                             tensor.CreateBuffer(oriDataType);
+
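+                            // If this weight has a LoRA (A, B) pair, merge the adapter delta into it.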
+                            if (loraDicts.find(weightName) != loraDicts.end()) {
+                                std::string loraA = loraDicts[weightName].first;
+                                std::string loraB = loraDicts[weightName].second;
+
+                                int inDim = loraTensors->itmeDict[loraA].intShape[1];
+                                int outDim = loraTensors->itmeDict[loraB].intShape[0];
+                                int lora = loraTensors->itmeDict[loraA].intShape[0]; // rank r of the LoRA decomposition
+
+                                AssertInFastLLM(loraTensors->itmeDict[loraA].dtype == "F32" && 
+                                                loraTensors->itmeDict[loraB].dtype == "F32", 
+                                                "Lora error: lora's dtype should be F32.");
+                                loraTensors->itmeDict[loraA].CreateBuffer(DataType::FLOAT32);
+                                loraTensors->itmeDict[loraB].CreateBuffer(DataType::FLOAT32);
+                                float *weightA = (float*)loraTensors->itmeDict[loraA].buffer;
+                                float *weightB = (float*)loraTensors->itmeDict[loraB].buffer;
+
+                                std::vector <float> loraFactor;
+                                loraFactor.resize(inDim * outDim, 0.0f);
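+                                // loraFactor = lora_B (outDim x rank) * lora_A (rank x inDim), row-major.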
+                                for (int i = 0; i < outDim; i++) {
+                                    for (int j = 0; j < lora; j++) {
+                                        for (int k = 0; k < inDim; k++) {
+                                            loraFactor[i * inDim + k] += weightB[i * lora + j] * weightA[j * inDim + k];
+                                        }
+                                    }
+                                }
+                                for (int i = 0; i < loraFactor.size(); i++) {
+                                    loraFactor[i] *= loraScaling;
+                                }
+
+                                loraTensors->itmeDict[loraA].ClearBuffer();
+                                loraTensors->itmeDict[loraB].ClearBuffer();
+
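+                                // Add the scaled delta in place, in the tensor's storage dtype.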
+                                if (oriDataType == DataType::BFLOAT16) {
+                                    uint16_t *bf16Weight = (uint16_t*)tensor.buffer;
+                                    for (int i = 0; i < loraFactor.size(); i++) {
+                                        uint32_t now = bf16Weight[i] << 16;
+                                        float newV = ((float*)&now)[0] + loraFactor[i];
+                                        bf16Weight[i] = ((uint32_t*)&newV)[0] >> 16;
+                                    }
+                                } else if (oriDataType == DataType::FLOAT16) {
+                                    uint16_t *fp16Weight = (uint16_t*)tensor.buffer;
+                                    for (int i = 0; i < loraFactor.size(); i++) {
+                                        fp16Weight[i] = float_to_half(half_to_float(fp16Weight[i]) + loraFactor[i]);
+                                    }
+                                } else if (oriDataType == DataType::FLOAT32) {
+                                    float *fp32Weight = (float*)tensor.buffer;
+                                    for (int i = 0; i < loraFactor.size(); i++) {
+                                        fp32Weight[i] = fp32Weight[i] + loraFactor[i];
+                                    }
+                                } else {
+                                    ErrorInFastLLM("Lora error, dtype should be float32, float16 or bfloat16.");
+                                }
+                            }
+
                             if (it.second == DATA_AUTO_CONV) {
                                 tensor.Transpose(oriDataType);
                             }
@@ -704,6 +794,8 @@
         printf("\n");
         fflush(stdout);
 
+        delete loraTensors;
+
         model->WarmUp();
         return std::unique_ptr<fastllm::basellm> (model);
     }
diff --git a/tools/fastllm_pytools/llm.py b/tools/fastllm_pytools/llm.py
index fc5f4e45..9f8bd5b9 100644
--- a/tools/fastllm_pytools/llm.py
+++ b/tools/fastllm_pytools/llm.py
@@ -19,7 +19,7 @@
 fastllm_lib.create_llm_model.argtypes = [ctypes.c_char_p]
 fastllm_lib.create_llm_model.restype = ctypes.c_int
 
-fastllm_lib.create_llm_model_fromhf.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.c_bool]
+fastllm_lib.create_llm_model_fromhf.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.c_bool, ctypes.c_char_p]
 fastllm_lib.create_llm_model_fromhf.restype = ctypes.c_int
 
 fastllm_lib.create_llm_model_fromhf_with_config.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.c_bool, ctypes.c_char_p]
@@ -387,7 +387,8 @@ def __init__ (self, path : str,
                   eos_token: List[str] = [],
                   tokenizer_type = "auto", 
                   model_json: str = "", 
-                  graph: type = None):
+                  graph: type = None, 
+                  lora: str = ""):
         if (graph != None):
             current_graph = graph()
             if (os.path.isdir(path) and os.path.isfile(os.path.join(path, "config.json"))):
@@ -435,7 +436,7 @@ def __init__ (self, path : str,
                                                                  ctypes.c_bool(self.hf_tokenizer != None), model_json.encode());
                 else:
                     self.model = fastllm_lib.create_llm_model_fromhf(path.encode(), fastllm_data_type_dict[dtype], int4g_groupcnt, 
-                                                                 ctypes.c_bool(self.hf_tokenizer != None));
+                                                                 ctypes.c_bool(self.hf_tokenizer != None), lora.encode());
             else:
                 print("path error: ", path);
                 exit(0)
diff --git a/tools/fastllm_pytools/util.py b/tools/fastllm_pytools/util.py
index 716f071e..b08eb043 100644
--- a/tools/fastllm_pytools/util.py
+++ b/tools/fastllm_pytools/util.py
@@ -13,6 +13,7 @@ def make_normal_parser(des: str) -> argparse.ArgumentParser:
     parser.add_argument('--max_batch', type = int, default = -1,  help = 'maximum number of prompts to infer concurrently')
     parser.add_argument('--device', type = str, help = 'device to run on')
     parser.add_argument('--custom', type = str, default = "", help = 'python file describing a custom model')
+    parser.add_argument('--lora', type = str, default = "", help = 'path to a LoRA adapter folder')
     return parser
 
 def make_normal_llm_model(args):
@@ -40,7 +41,7 @@ def make_normal_llm_model(args):
         spec.loader.exec_module(custom_module)
         if (hasattr(custom_module, "__model__")):
             graph = getattr(custom_module, "__model__")
-    model = llm.model(args.path, dtype = args.dtype, graph = graph, tokenizer_type = "auto")
+    model = llm.model(args.path, dtype = args.dtype, graph = graph, tokenizer_type = "auto", lora = args.lora)
     model.set_atype(args.atype)
     if (args.max_batch > 0):
         model.set_max_batch(args.max_batch)
diff --git a/tools/src/pytools.cpp b/tools/src/pytools.cpp
index 4e4ba012..4f1a8671 100644
--- a/tools/src/pytools.cpp
+++ b/tools/src/pytools.cpp
@@ -102,10 +102,10 @@ extern "C" {
         return id;
     }
 
-    DLL_EXPORT int create_llm_model_fromhf(char *path, int dataType, int groupCnt, bool skipTokenizer) {
+    DLL_EXPORT int create_llm_model_fromhf(char *path, int dataType, int groupCnt, bool skipTokenizer, char *lora) {
         models.locker.lock();
         int id = models.models.size();
-        models.models[id] = fastllm::CreateLLMModelFromHF(path, (fastllm::DataType)dataType, groupCnt, skipTokenizer);
+        models.models[id] = fastllm::CreateLLMModelFromHF(path, (fastllm::DataType)dataType, groupCnt, skipTokenizer, "", lora);
         models.locker.unlock();
         return id;
     }