From 1d373baa2c60320d235c329750153776eb7b6060 Mon Sep 17 00:00:00 2001
From: 黄宇扬
Date: Fri, 31 May 2024 13:13:55 +0800
Subject: [PATCH] Optimize the conversation KV cache; update the main program
 (support reading Hugging Face models directly, setting a system prompt,
 setting eos_token, etc.)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 include/models/basellm.h |  38 ++++++++++-
 main.cpp                 |  53 ++++++++++++---
 src/fastllm.cpp          |  12 +++-
 src/models/basellm.cpp   | 138 +++++++++++++++++++++++++++++++--------
 4 files changed, 200 insertions(+), 41 deletions(-)

diff --git a/include/models/basellm.h b/include/models/basellm.h
index 85ad8f54..d3125bcb 100644
--- a/include/models/basellm.h
+++ b/include/models/basellm.h
@@ -48,6 +48,40 @@ namespace fastllm {
         void RemoveHandle(int handleId);
     };
 
+    struct PastKVCacheMemory {
+        std::string prompt;
+        int tokens;
+        int recordTimes = 0;
+        long long flushTime;
+        std::vector <std::pair <Data, Data> > kv;
+
+        PastKVCacheMemory () {}
+
+        PastKVCacheMemory (const std::string &prompt, int tokens, long long flushTime, std::vector <std::pair <Data, Data> > *kv);
+    };
+
+    struct PastKVCacheManager {
+        std::mutex locker;
+        int maxRecordNum = 5;
+        long long flushTime = 0;
+        std::map <std::string, PastKVCacheMemory*> memorys;
+
+        // Set the maximum number of records to keep
+        void SetMaxRecordNum(int maxRecordNum);
+
+        // Insert a record; if it already exists, bump its reference count instead
+        void Record(const std::string &prompt, int tokens, std::vector <std::pair <Data, Data> > *kv);
+
+        // Try to remove a record; it is only freed once its reference count reaches 0
+        void Remove(std::string prompt);
+
+        // Return the record with the longest matching prefix and keep the manager locked
+        PastKVCacheMemory *Get(const std::string &prompt);
+
+        // Release the lock taken by Get()
+        void Unlock();
+    };
+
     class basellm {
     public:
         basellm() {};
@@ -176,9 +210,7 @@ namespace fastllm {
 
         int tokensLimit = -1;
 
-        std::string lastPrompt = "";
-        std::vector <std::pair <Data, Data> > *lastKeyValues = nullptr;
-        int lastPromptTokens = 0;
+        PastKVCacheManager pastKVCacheManager;
         bool saveHistoryChat = false;
 
         DataType dataType = DataType::FLOAT32;
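Get() is the unusual part of this interface: it returns the record whose prompt is the longest prefix of the new input and deliberately leaves the manager locked, so the cached KV tensors cannot be evicted while the caller copies them out; the caller must call Unlock() on every path, hit or miss. A minimal sketch of that call pattern, assuming pastKVCacheManager is accessible to the caller (basellm::Response further below is the real user of this API):

#include "models/basellm.h"

// Copy a cached KV prefix into per-request buffers, then release the lock
// taken by Get(). Get() keeps the manager locked even on a miss.
void ReusePrefix(fastllm::basellm *model, const std::string &prompt,
                 std::vector <std::pair <fastllm::Data, fastllm::Data> > &pastKeyValues) {
    fastllm::PastKVCacheMemory *memory = model->pastKVCacheManager.Get(prompt);
    if (memory != nullptr) {
        for (size_t i = 0; i < memory->kv.size(); i++) {
            pastKeyValues[i].first.CopyFrom(memory->kv[i].first);
            pastKeyValues[i].second.CopyFrom(memory->kv[i].second);
        }
    }
    model->pastKVCacheManager.Unlock();
}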
"--repeat_penalty") { generationConfig.repeat_penalty = atof(sargv[++i].c_str()); + } else if (sargv[i] == "--system") { + config.systemPrompt = sargv[++i]; + } else if (sargv[i] == "--eos_token") { + config.eosToken.insert(sargv[++i]); + } else if (sargv[i] == "--dtype") { + std::string dtypeStr = sargv[++i]; + if (dtypeStr.size() > 5 && dtypeStr.substr(0, 5) == "int4g") { + config.groupCnt = atoi(dtypeStr.substr(5).c_str()); + dtypeStr = dtypeStr.substr(0, 5); + } + fastllm::AssertInFastLLM(dataTypeDict.find(dtypeStr) != dataTypeDict.end(), + "Unsupport data type: " + dtypeStr); + config.dtype = dataTypeDict[dtypeStr]; } else { Usage(); exit(-1); @@ -51,9 +82,6 @@ void ParseArgs(int argc, char **argv, RunConfig &config, fastllm::GenerationConf } int main(int argc, char **argv) { - int round = 0; - std::string history = ""; - RunConfig config; fastllm::GenerationConfig generationConfig; ParseArgs(argc, argv, config, generationConfig); @@ -61,24 +89,32 @@ int main(int argc, char **argv) { fastllm::PrintInstructionInfo(); fastllm::SetThreads(config.threads); fastllm::SetLowMemMode(config.lowMemMode); - auto model = fastllm::CreateLLMModelFromFile(config.path); + bool isHFDir = access((config.path + "/config.json").c_str(), R_OK) == 0 || access((config.path + "config.json").c_str(), R_OK) == 0; + auto model = !isHFDir ? fastllm::CreateLLMModelFromFile(config.path) : fastllm::CreateLLMModelFromHF(config.path, config.dtype, config.groupCnt); model->SetSaveHistoryChat(true); + + for (auto &it : config.eosToken) { + generationConfig.stop_token_ids.insert(model->weight.tokenizer.GetTokenId(it)); + } + std::string systemConfig = config.systemPrompt; + fastllm::ChatMessages messages = {{"system", systemConfig}}; static std::string modelType = model->model_type; printf("欢迎使用 %s 模型. 
diff --git a/src/fastllm.cpp b/src/fastllm.cpp
index 45db3277..7ce8cd01 100644
--- a/src/fastllm.cpp
+++ b/src/fastllm.cpp
@@ -299,7 +299,7 @@ namespace fastllm {
         this->cacheUid = ori.cacheUid;
 
         // std::cout << "copy constructor called" << std::endl;
-        if (ori.dims != this->dims || this->cpuData == nullptr || ori.dataType != this->dataType) {
+        if (ori.expansionDims != this->expansionDims || ori.dims != this->dims || this->cpuData == nullptr || ori.dataType != this->dataType) {
             if (ori.dims.size() == 0) {
                 delete[] this->cpuData;
                 this->dataType = ori.dataType;
@@ -309,8 +309,14 @@ namespace fastllm {
                 return;
             }
             this->dataType = ori.dataType;
-            this->Resize(ori.dims);
-            this->Allocate();
+            if (ori.expansionDims.size() > 0 && ori.expansionDims != ori.dims) {
+                // Preserve the source's pre-allocated (expansion) capacity so a
+                // copied KV-cache tensor can still grow in place.
+                this->Expansion(ori.expansionDims);
+                this->Resize(ori.dims);
+                this->Allocate();
+            } else {
+                this->Resize(ori.dims);
+                this->Allocate();
+            }
         }
         std::memcpy(this->cpuData, ori.cpuData, this->GetBytes());
     }
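This CopyFrom() change matters for the KV cache: cached Data tensors are typically over-allocated (expansionDims) so they can grow in place as tokens are appended, and a copy that only honored the logical dims would silently drop that reserve. The toy class below illustrates the same size-versus-capacity distinction with a plain array; it is an analogy, not fastllm's actual Data implementation:

#include <cstdio>
#include <cstring>

// size plays the role of dims (logical extent); capacity plays the role of
// expansionDims (allocated extent, kept so the buffer can grow in place).
struct Buffer {
    float *data = nullptr;
    int size = 0, capacity = 0;

    void CopyFrom(const Buffer &ori) {
        // Reallocate when logical size OR reserved capacity differ, mirroring
        // the added ori.expansionDims != this->expansionDims check above.
        if (ori.size != size || ori.capacity != capacity || data == nullptr) {
            delete[] data;
            capacity = (ori.capacity > ori.size) ? ori.capacity : ori.size;
            size = ori.size;
            data = new float[capacity];
        }
        std::memcpy(data, ori.data, size * sizeof(float));
    }
    ~Buffer() { delete[] data; }
};

int main() {
    Buffer a, b;
    a.capacity = 8; a.size = 4; a.data = new float[8]{1, 2, 3, 4};
    b.CopyFrom(a);
    printf("size = %d, capacity = %d\n", b.size, b.capacity);  // size = 4, capacity = 8
    return 0;
}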
diff --git a/src/models/basellm.cpp b/src/models/basellm.cpp
index a3d71a80..ee42b278 100644
--- a/src/models/basellm.cpp
+++ b/src/models/basellm.cpp
@@ -55,26 +55,94 @@ namespace fastllm {
         isEnding = false;
         preTokens = 0;
     }
+
+    PastKVCacheMemory::PastKVCacheMemory(const std::string &prompt, int tokens, long long flushTime, std::vector <std::pair <Data, Data> > *kv) {
+        this->prompt = prompt;
+        this->tokens = tokens;
+        this->flushTime = flushTime;
+        this->recordTimes = 1;
+        auto dataType = (*kv)[0].first.dataType;
+        for (int i = 0; i < kv->size(); i++) {
+            this->kv.push_back(std::make_pair(Data(dataType), Data(dataType)));
+        }
+        for (int i = 0; i < kv->size(); i++) {
+            this->kv[i].first.CopyFrom((*kv)[i].first);
+            this->kv[i].second.CopyFrom((*kv)[i].second);
+        }
+    }
+
+    void PastKVCacheManager::SetMaxRecordNum(int maxRecordNum) {
+        std::lock_guard <std::mutex> lock(this->locker);
+        this->maxRecordNum = maxRecordNum;
+    }
+
+    void PastKVCacheManager::Record(const std::string &prompt, int tokens, std::vector <std::pair <Data, Data> > *kv) {
+        std::lock_guard <std::mutex> lock(this->locker);
+        if (this->memorys.find(prompt) != this->memorys.end()) {
+            this->memorys[prompt]->recordTimes++;
+            this->memorys[prompt]->flushTime = ++flushTime;
+            return;
+        }
+
+        if ((int) this->memorys.size() >= this->maxRecordNum) {
+            // Evict the least recently used record (smallest flushTime).
+            std::string evictPrompt = "";
+            long long minFlushTime = (1LL << 60);
+            for (auto &it : this->memorys) {
+                if (it.second->flushTime < minFlushTime) {
+                    minFlushTime = it.second->flushTime;
+                    evictPrompt = it.first;
+                }
+            }
+            delete this->memorys[evictPrompt];
+            this->memorys.erase(this->memorys.find(evictPrompt));
+        }
+
+        this->memorys[prompt] = new PastKVCacheMemory(prompt, tokens, ++flushTime, kv);
+    }
+
+    void PastKVCacheManager::Remove(std::string prompt) {
+        std::lock_guard <std::mutex> lock(this->locker);
+        if (this->memorys.find(prompt) != this->memorys.end()) {
+            if ((--this->memorys[prompt]->recordTimes) <= 0) {
+                delete this->memorys[prompt];
+                this->memorys.erase(this->memorys.find(prompt));
+            }
+        }
+    }
+
+    PastKVCacheMemory *PastKVCacheManager::Get(const std::string &prompt) {
+        locker.lock();
+        std::string maxPrompt = "";
+        for (auto &it : this->memorys) {
+            const std::string &cur = it.first;
+            if (cur.size() > maxPrompt.size() && cur.size() <= prompt.size() && prompt.substr(0, cur.size()) == cur) {
+                maxPrompt = cur;
+            }
+        }
+        if (maxPrompt == "") {
+            // The manager stays locked even on a miss; the caller must still Unlock().
+            return nullptr;
+        }
+        this->memorys[maxPrompt]->flushTime = ++this->flushTime;
+        return this->memorys[maxPrompt];
+    }
+
+    void PastKVCacheManager::Unlock() {
+        locker.unlock();
+    }
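Eviction in Record() is least-recently-used, implemented with flushTime as a monotonically increasing logical clock that is bumped on every insert and every cache hit rather than with wall-clock timestamps. A stripped-down, self-contained illustration of the policy (a toy map from key to last-touch counter, not the fastllm structs):

#include <cstdio>
#include <map>
#include <string>

int main() {
    std::map <std::string, long long> memorys;  // key -> last-touch counter
    long long flushTime = 0;
    const int maxRecordNum = 2;

    for (const char *k : {"a", "b", "a", "c"}) {  // access order
        if (memorys.find(k) != memorys.end()) {
            memorys[k] = ++flushTime;  // hit: refresh the counter
            continue;
        }
        if ((int) memorys.size() >= maxRecordNum) {
            // Evict the entry with the smallest counter (least recently used).
            auto victim = memorys.begin();
            for (auto it = memorys.begin(); it != memorys.end(); ++it) {
                if (it->second < victim->second) victim = it;
            }
            printf("evict %s\n", victim->first.c_str());  // prints: evict b
            memorys.erase(victim);
        }
        memorys[k] = ++flushTime;
    }
    return 0;
}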
 
     std::string basellm::Response(const std::string &oriInput, RuntimeResult retCb,
                                   const fastllm::GenerationConfig &generationConfig) {
         std::string input = oriInput;
+        PastKVCacheMemory *memory = nullptr;
+        std::string oldPrompt;
+        int oldTokens = 0;
         if (this->saveHistoryChat) {
-            if (lastKeyValues != nullptr) {
-                if (input.size() < lastPrompt.size() || (input.substr(0, lastPrompt.size()) != lastPrompt)) {
-                    lastPrompt = "";
-                    lastPromptTokens = 0;
-                    delete lastKeyValues;
-                    lastKeyValues = nullptr;
-                } else {
-                    input = input.substr(lastPrompt.size());
-                }
+            memory = pastKVCacheManager.Get(input);
+            if (memory != nullptr) {
+                oldPrompt = memory->prompt;
+                oldTokens = memory->tokens;
+                input = input.substr(memory->prompt.size());
             }
-        } else {
-            lastPrompt = "";
-            lastPromptTokens = 0;
-            delete lastKeyValues;
-            lastKeyValues = nullptr;
         }
 
         //printf("lastPrompt = %s\n", lastPrompt.c_str());
@@ -97,21 +165,32 @@ namespace fastllm {
         for (int i = 0; i < inputTokenData.Count(0); i++) {
             inputTokens[0].push_back(((float *) inputTokenData.cpuData)[i]);
         }
-
-        if (lastKeyValues == nullptr) {
-            lastKeyValues = new std::vector <std::pair <Data, Data> >();
-            for (int i = 0; i < block_cnt; i++) {
-                lastKeyValues->push_back(std::make_pair(Data(this->dataType), Data(this->dataType)));
-                lastKeyValues->back().first.SetKVCache();
-                lastKeyValues->back().second.SetKVCache();
+
+        std::vector <std::pair <Data, Data> > pastKeyValues;
+        for (int i = 0; i < block_cnt; i++) {
+            pastKeyValues.push_back(std::make_pair(Data(this->dataType),
+                                                   Data(this->dataType)));
+        }
+
+        if (this->saveHistoryChat) {
+            if (memory != nullptr) {
+                for (int i = 0; i < block_cnt; i++) {
+                    pastKeyValues[i].first.CopyFrom(memory->kv[i].first);
+                    pastKeyValues[i].second.CopyFrom(memory->kv[i].second);
+                }
             }
+            pastKVCacheManager.Unlock();
+        }
+
+        // Mark every layer's K and V as KV-cache tensors.
+        for (int i = 0; i < block_cnt; i++) {
+            pastKeyValues[i].first.SetKVCache();
+            pastKeyValues[i].second.SetKVCache();
         }
-        std::vector <std::pair <Data, Data> > &pastKeyValues = (*lastKeyValues);
 
         std::string retString = "";
         std::vector <float> results;
         LastTokensManager tokens(1, generationConfig.last_n);
-        int promptLen = lastPromptTokens + inputTokens[0].size(), index = 0;
+        int promptLen = oldTokens + inputTokens[0].size(), index = 0;
         int add_special_tokens = generationConfig.add_special_tokens ? 1 : 0;
         FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}, {"add_special_tokens", add_special_tokens}},
                       inputIds, attentionMask, positionIds);
@@ -121,7 +200,7 @@ namespace fastllm {
             int ret = Forward(inputIds, attentionMask, positionIds, pastKeyValues, generationConfig, tokens);
             tokens.units[0].Push(ret);
             if (ret == eos_token_id
-                || generationConfig.stop_token_ids.find(index) != generationConfig.stop_token_ids.end()) {
+                || generationConfig.stop_token_ids.find(ret) != generationConfig.stop_token_ids.end()) {
                 break;
             }
 
@@ -171,8 +250,15 @@ namespace fastllm {
             retCb(-1, retString.c_str());
 #endif
 
-        lastPrompt += (input + retString);
-        lastPromptTokens = promptLen + index;
+        if (this->saveHistoryChat) {
+            if (oldPrompt != "") {
+                pastKVCacheManager.Remove(oldPrompt);
+            }
+            pastKVCacheManager.Record(oriInput + retString, promptLen + index, &pastKeyValues);
+        }
+
         return retString;
     }
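With these changes a caller only opts in via SetSaveHistoryChat(true); Response() then records each finished conversation in the manager and transparently reuses the longest cached prefix on the next call, so multi-turn chats avoid re-encoding the shared history. A usage sketch along the lines of main.cpp above (the model path and prompts are placeholders):

#include "model.h"

int main() {
    auto model = fastllm::CreateLLMModelFromFile("model.flm");  // placeholder path
    model->SetSaveHistoryChat(true);
    fastllm::GenerationConfig generationConfig;

    fastllm::ChatMessages messages = {{"user", "Hello"}};
    std::string first = model->Response(model->ApplyChatTemplate(messages),
                                        [](int index, const char *content) {}, generationConfig);

    // The second prompt extends the first conversation, so Get() finds the
    // recorded prefix and only the new suffix tokens need a prefill pass.
    messages.push_back(std::make_pair("assistant", first));
    messages.push_back(std::make_pair("user", "Tell me more"));
    std::string second = model->Response(model->ApplyChatTemplate(messages),
                                         [](int index, const char *content) {}, generationConfig);
    return 0;
}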