Support LoRA
黄宇扬 committed Sep 19, 2024
1 parent c1cfdcf commit 4aafb94
Showing 5 changed files with 101 additions and 11 deletions.
3 changes: 2 additions & 1 deletion include/model.h
@@ -19,7 +19,8 @@ namespace fastllm {
DataType linearDataType,
int groupCnt = -1,
bool skipTokenizer = false,
const std::string &modelConfig = "",
const std::string &loraPath = "");

std::unique_ptr<basellm> CreateLLMTokenizerFromHF(const std::string &modelPath);
}
95 changes: 91 additions & 4 deletions src/model.cpp
@@ -270,9 +270,13 @@ namespace fastllm {
DataType srcType;
if (this->dtype == "BF16") {
srcType = DataType::BFLOAT16;
} else if (this->dtype == "F16") {
srcType = DataType::FLOAT16;
} else if (this->dtype == "F32") {
srcType = DataType::FLOAT32;
if (dstType != DataType::FLOAT32) {
ErrorInFastLLM("SafeTensorItem.CreateBuffer: unsupport src dtype " + this->dtype + "\n");
}
} else {
ErrorInFastLLM("SafeTensorItem.CreateBuffer: unsupport src dtype " + this->dtype + "\n");
}
@@ -532,9 +536,35 @@ namespace fastllm {

// Load from an HF folder; only models in safetensors format are supported
std::unique_ptr <basellm> CreateLLMModelFromHF(const std::string &modelPath,
DataType linearDataType, int groupCnt, bool skipTokenizer, const std::string &modelConfig,
const std::string &loraPath) {
std::map <std::string, std::pair <std::string, std::string> > loraDicts;
SafeTensors *loraTensors = nullptr;
float loraScaling = 1.0f;
if (loraPath != "") {
std::string path = loraPath;
if (path.back() != '/' && path.back() != '\\') {
path += "/";
}
loraTensors = new SafeTensors({path + "adapter_model.safetensors"});
for (auto &it : loraTensors->GetSortedItemNames()) {
if (it.size() >= 31 &&
it.substr(0, 17) == "base_model.model." &&
(it.substr(it.size() - 14) == ".lora_A.weight" || it.substr(it.size() - 14) == ".lora_B.weight")) {
std::string originalName = it.substr(17, it.size() - 31) + ".weight";
if (it.substr(it.size() - 14) == ".lora_A.weight") {
loraDicts[originalName].first = it;
} else {
loraDicts[originalName].second = it;
}
}
}
std::string loraConfigError;
auto loraConfig = json11::Json::parse(ReadAllFile(path + "adapter_config.json"), loraConfigError);
loraScaling = loraConfig["lora_alpha"].number_value() / loraConfig["r"].number_value();
}

bool isJsonModel = (modelConfig.size() > 0);
std::string path = modelPath;
if (path.back() != '/' && path.back() != '\\') {
path += "/";
@@ -681,6 +711,61 @@ namespace fastllm {
oriDataType = DataType::FLOAT16;
}
tensor.CreateBuffer(oriDataType);

if (loraDicts.find(weightName) != loraDicts.end()) {
std::string loraA = loraDicts[weightName].first;
std::string loraB = loraDicts[weightName].second;

int inDim = loraTensors->itmeDict[loraA].intShape[1];
int outDim = loraTensors->itmeDict[loraB].intShape[0];
int lora = loraTensors->itmeDict[loraA].intShape[0];

AssertInFastLLM(loraTensors->itmeDict[loraA].dtype == "F32" &&
loraTensors->itmeDict[loraB].dtype == "F32",
"Lora error: lora's dtype should be F32.");
loraTensors->itmeDict[loraA].CreateBuffer(DataType::FLOAT32);
loraTensors->itmeDict[loraB].CreateBuffer(DataType::FLOAT32);
float *weightA = (float*)loraTensors->itmeDict[loraA].buffer;
float *weightB = (float*)loraTensors->itmeDict[loraB].buffer;

std::vector <float> loraFactor;
loraFactor.resize(inDim * outDim, 0.0f);
for (int i = 0; i < outDim; i++) {
for (int j = 0; j < lora; j++) {
for (int k = 0; k < inDim; k++) {
loraFactor[i * inDim + k] += weightB[i * lora + j] * weightA[j * inDim + k];
}
}
}
for (int i = 0; i < loraFactor.size(); i++) {
loraFactor[i] *= loraScaling;
}

loraTensors->itmeDict[loraA].ClearBuffer();
loraTensors->itmeDict[loraB].ClearBuffer();

if (oriDataType == DataType::BFLOAT16) {
uint16_t *fp16Weight = (uint16_t*)tensor.buffer;
for (int i = 0; i < loraFactor.size(); i++) {
uint32_t now = fp16Weight[i] << 16;
float newV = ((float*)&now)[0] + loraFactor[i];
fp16Weight[i] = ((uint32_t*)&newV)[0] >> 16;
}
} else if (oriDataType == DataType::FLOAT16) {
uint16_t *fp16Weight = (uint16_t*)tensor.buffer;
for (int i = 0; i < loraFactor.size(); i++) {
fp16Weight[i] = float_to_half(half_to_float(fp16Weight[i]) + loraFactor[i]);
}
} else if (oriDataType == DataType::FLOAT32) {
float *fp32Weight = (float*)tensor.buffer;
for (int i = 0; i < loraFactor.size(); i++) {
fp32Weight[i] = fp32Weight[i] + loraFactor[i];
}
} else {
ErrorInFastLLM("Lora error, dtype should be float32, float16 or bfloat16.");
}
}

if (it.second == DATA_AUTO_CONV) {
tensor.Transpose(oriDataType);
}
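For reference, the merge loop in the hunk above folds each matched adapter pair into its base weight as merged = W + (lora_alpha / r) * (B x A), where A is taken from lora_A.weight (shape r x in_dim) and B from lora_B.weight (shape out_dim x r). Below is a minimal pure-Python sketch of that arithmetic with toy, made-up sizes and values; it mirrors the triple loop above and is only an illustration, not code from this repository.

# Toy illustration of the LoRA fold-in: merged = W + (lora_alpha / r) * (B @ A).
# All sizes and numbers are made up; fastllm does this in C++ over flat buffers.
lora_alpha, r = 16.0, 2              # values read from adapter_config.json
in_dim, out_dim = 3, 2               # toy dimensions
scaling = lora_alpha / r

W = [[0.1, 0.2, 0.3],                # base weight, out_dim x in_dim
     [0.4, 0.5, 0.6]]
A = [[0.01, 0.00, 0.02],             # lora_A.weight, r x in_dim
     [0.00, 0.03, 0.00]]
B = [[1.0, 0.5],                     # lora_B.weight, out_dim x r
     [0.0, 2.0]]

# Same triple loop as the C++ merge: accumulate B @ A, scaled, into W.
for i in range(out_dim):
    for j in range(r):
        for k in range(in_dim):
            W[i][k] += scaling * B[i][j] * A[j][k]

print(W)                             # merged weight used in place of W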
@@ -704,6 +789,8 @@ namespace fastllm {
printf("\n");
fflush(stdout);

delete loraTensors;

model->WarmUp();
return std::unique_ptr<fastllm::basellm> (model);
}
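The BFLOAT16 branch in the merge code above relies on bfloat16 being exactly the high 16 bits of an IEEE-754 float32: each stored value is widened by a 16-bit left shift, the LoRA delta is added in float32, and the sum is truncated back down without rounding, just as the C++ does. A small Python sketch of that round trip, using an arbitrary example value:

# bfloat16 <-> float32 via bit manipulation, mirroring the shifts used above.
import struct

def bf16_to_f32(bits: int) -> float:
    # widen: the bf16 bit pattern occupies the top 16 bits of a float32
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]

def f32_to_bf16(value: float) -> int:
    # narrow by truncation, matching the C++ ">> 16" (no rounding)
    return struct.unpack("<I", struct.pack("<f", value))[0] >> 16

w = 0x3FC0                                  # bf16 pattern for 1.5 (example)
w = f32_to_bf16(bf16_to_f32(w) + 0.25)      # add a LoRA delta in float32
print(hex(w))                               # 0x3fe0, i.e. 1.75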
7 changes: 4 additions & 3 deletions tools/fastllm_pytools/llm.py
@@ -19,7 +19,7 @@
fastllm_lib.create_llm_model.argtypes = [ctypes.c_char_p]
fastllm_lib.create_llm_model.restype = ctypes.c_int

fastllm_lib.create_llm_model_fromhf.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.c_bool, ctypes.c_char_p]
fastllm_lib.create_llm_model_fromhf.restype = ctypes.c_int

fastllm_lib.create_llm_model_fromhf_with_config.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.c_bool, ctypes.c_char_p]
@@ -387,7 +387,8 @@ def __init__ (self, path : str,
eos_token: List[str] = [],
tokenizer_type = "auto",
model_json: str = "",
graph: type = None,
lora: str = ""):
if (graph != None):
current_graph = graph()
if (os.path.isdir(path) and os.path.isfile(os.path.join(path, "config.json"))):
@@ -435,7 +436,7 @@ def __init__ (self, path : str,
ctypes.c_bool(self.hf_tokenizer != None), model_json.encode());
else:
self.model = fastllm_lib.create_llm_model_fromhf(path.encode(), fastllm_data_type_dict[dtype], int4g_groupcnt,
ctypes.c_bool(self.hf_tokenizer != None), lora.encode());
else:
print("path error: ", path);
exit(0)
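With the new keyword argument, loading an HF-format model together with a LoRA adapter from Python looks roughly like the sketch below. The directory paths and the dtype string are placeholders, and the import assumes the tools/fastllm_pytools package is available as fastllm_pytools; an empty lora string keeps the previous behaviour.

from fastllm_pytools import llm   # assumed import path for tools/fastllm_pytools/llm.py

# Paths and dtype are placeholders; the lora folder is expected to contain
# adapter_model.safetensors and adapter_config.json.
model = llm.model("/models/qwen2-7b-hf",
                  dtype = "float16",
                  lora = "/loras/my-adapter")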
3 changes: 2 additions & 1 deletion tools/fastllm_pytools/util.py
@@ -13,6 +13,7 @@ def make_normal_parser(des: str) -> argparse.ArgumentParser:
parser.add_argument('--max_batch', type = int, default = -1, help = 'maximum number of requests to infer at the same time')
parser.add_argument('--device', type = str, help = 'device to use')
parser.add_argument('--custom', type = str, default = "", help = 'python file describing a custom model')
parser.add_argument('--lora', type = str, default = "", help = 'path to the LoRA adapter')
return parser

def make_normal_llm_model(args):
@@ -40,7 +41,7 @@ def make_normal_llm_model(args):
spec.loader.exec_module(custom_module)
if (hasattr(custom_module, "__model__")):
graph = getattr(custom_module, "__model__")
model = llm.model(args.path, dtype = args.dtype, graph = graph, tokenizer_type = "auto", lora = args.lora)
model.set_atype(args.atype)
if (args.max_batch > 0):
model.set_max_batch(args.max_batch)
4 changes: 2 additions & 2 deletions tools/src/pytools.cpp
@@ -102,10 +102,10 @@ extern "C" {
return id;
}

DLL_EXPORT int create_llm_model_fromhf(char *path, int dataType, int groupCnt, bool skipTokenizer, char *lora) {
models.locker.lock();
int id = models.models.size();
models.models[id] = fastllm::CreateLLMModelFromHF(path, (fastllm::DataType)dataType, groupCnt, skipTokenizer, "", lora);
models.locker.unlock();
return id;
}
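At the C ABI level the exported function now takes a fifth char* argument carrying the adapter path, matching the argtypes change in llm.py above. A hedged ctypes sketch of calling it directly; the library file name, model path, adapter path, and data-type id are all placeholders:

import ctypes

fastllm_lib = ctypes.cdll.LoadLibrary("./libfastllm_tools.so")   # placeholder library name
fastllm_lib.create_llm_model_fromhf.argtypes = [
    ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.c_bool, ctypes.c_char_p]
fastllm_lib.create_llm_model_fromhf.restype = ctypes.c_int

model_id = fastllm_lib.create_llm_model_fromhf(
    b"/models/qwen2-7b-hf",    # HF model folder (placeholder)
    7,                         # fastllm data type id (placeholder value)
    -1,                        # groupCnt, -1 = default
    False,                     # skipTokenizer
    b"/loras/my-adapter")      # LoRA adapter folder; b"" disables merging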
