From 8abf3079fcd34b09fc536eabdf342be9976ad22b Mon Sep 17 00:00:00 2001
From: 黄宇扬
Date: Tue, 4 Jun 2024 18:56:15 +0800
Subject: [PATCH] Restore the previous chat mode (reduce copying)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 include/models/basellm.h |  4 +++
 src/models/basellm.cpp   | 62 ++++++++++++++++------------------
 2 files changed, 29 insertions(+), 37 deletions(-)

diff --git a/include/models/basellm.h b/include/models/basellm.h
index d3125bcb..3a1e2ddc 100644
--- a/include/models/basellm.h
+++ b/include/models/basellm.h
@@ -213,6 +213,10 @@ namespace fastllm {
         PastKVCacheManager pastKVCacheManager;
         bool saveHistoryChat = false;
 
+        std::string lastPrompt = "";
+        std::vector <std::pair <Data, Data> > *lastKeyValues = nullptr;
+        int lastPromptTokens = 0;
+
         DataType dataType = DataType::FLOAT32;
     };
 }
diff --git a/src/models/basellm.cpp b/src/models/basellm.cpp
index e0a38c8d..520c6ead 100644
--- a/src/models/basellm.cpp
+++ b/src/models/basellm.cpp
@@ -136,16 +136,22 @@ namespace fastllm {
     std::string basellm::Response(const std::string &oriInput, RuntimeResult retCb,
                                   const fastllm::GenerationConfig &generationConfig) {
         std::string input = oriInput;
-        PastKVCacheMemory *memory;
-        std::string oldPrompt;
-        int oldTokens = 0;
         if (this->saveHistoryChat) {
-            memory = pastKVCacheManager.Get(input);
-            if (memory != nullptr) {
-                oldPrompt = memory->prompt;
-                oldTokens = memory->tokens;
-                input = input.substr(memory->prompt.size());
+            if (lastKeyValues != nullptr) {
+                if (input.size() < lastPrompt.size() || (input.substr(0, lastPrompt.size()) != lastPrompt)) {
+                    lastPrompt = "";
+                    lastPromptTokens = 0;
+                    delete lastKeyValues;
+                    lastKeyValues = nullptr;
+                } else {
+                    input = input.substr(lastPrompt.size());
+                }
             }
+        } else {
+            lastPrompt = "";
+            lastPromptTokens = 0;
+            delete lastKeyValues;
+            lastKeyValues = nullptr;
         }
 
         //printf("lastPrompt = %s\n", lastPrompt.c_str());
@@ -168,32 +174,21 @@ namespace fastllm {
         for (int i = 0; i < inputTokenData.Count(0); i++) {
             inputTokens[0].push_back(((float *) inputTokenData.cpuData)[i]);
         }
 
-        std::vector <std::pair <Data, Data> > pastKeyValues;
-        for (int i = 0; i < block_cnt; i++) {
-            pastKeyValues.push_back(std::make_pair(Data(this->dataType),
-                                                   Data(this->dataType)));
-        }
-
-        if (this->saveHistoryChat) {
-            if (memory != nullptr) {
-                for (int i = 0; i < block_cnt; i++) {
-                    pastKeyValues[i].first.CopyFrom(memory->kv[i].first);
-                    pastKeyValues[i].second.CopyFrom(memory->kv[i].second);
-                }
+        if (lastKeyValues == nullptr) {
+            lastKeyValues = new std::vector <std::pair <Data, Data> >();
+            for (int i = 0; i < block_cnt; i++) {
+                lastKeyValues->push_back(std::make_pair(Data(this->dataType), Data(this->dataType)));
+                lastKeyValues->back().first.SetKVCache();
+                lastKeyValues->back().second.SetKVCache();
             }
-            pastKVCacheManager.Unlock();
-        }
-
-        for (int i = 0; i < block_cnt; i++) {
-            pastKeyValues.back().first.SetKVCache();
-            pastKeyValues.back().second.SetKVCache();
-        }
+        }
+        std::vector <std::pair <Data, Data> > &pastKeyValues = (*lastKeyValues);
 
         std::string retString = "";
         std::vector <float> results;
         LastTokensManager tokens(1, generationConfig.last_n);
-        int promptLen = oldTokens + inputTokens[0].size(), index = 0;
+        int promptLen = lastPromptTokens + inputTokens[0].size(), index = 0;
         int add_special_tokens = generationConfig.add_special_tokens? 1: 0;
         FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}, {"add_special_tokens", add_special_tokens}},
                       inputIds, attentionMask, positionIds);
@@ -253,15 +248,8 @@ namespace fastllm {
             retCb(-1, retString.c_str());
 #endif
 
-        if (this->saveHistoryChat) {
-            std::string currentPrompt;
-            int currentTokens;
-            if (oldPrompt != "") {
-                pastKVCacheManager.Remove(oldPrompt);
-            }
-            pastKVCacheManager.Record(oriInput + retString, promptLen + index, &pastKeyValues);
-        }
-
+        lastPrompt += (input + retString);
+        lastPromptTokens = promptLen + index;
         return retString;
     }
 
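Note on the mechanism: the patch replaces the PastKVCacheManager lookup with a
single-slot prefix cache. If the new prompt begins with the previous prompt plus
the previous reply, only the new suffix is prefilled and the existing per-block
KV cache is reused in place; otherwise the cache is deleted and rebuilt. Below
is a minimal standalone C++ sketch of that reuse rule, assumed rather than taken
from the fastllm sources: CacheState is a hypothetical stand-in for the
lastPrompt / lastPromptTokens / lastKeyValues members (the real cache holds one
(key, value) Data pair per transformer block).

// Sketch only: mirrors the prefix test in basellm::Response() above.
#include <iostream>
#include <string>

struct CacheState {
    std::string lastPrompt;    // prompt + reply text already covered by the KV cache
    int lastPromptTokens = 0;  // number of tokens covered by the cache
    bool hasCache = false;     // stands in for lastKeyValues != nullptr
};

// Returns the part of `input` that still needs a prefill forward pass.
// If the cached text is not a prefix of `input`, the cache is stale and
// is reset, mirroring the delete / nullptr branch in the patch.
std::string ReusableSuffix(CacheState &st, const std::string &input) {
    if (st.hasCache) {
        if (input.size() < st.lastPrompt.size() ||
            input.compare(0, st.lastPrompt.size(), st.lastPrompt) != 0) {
            st = CacheState{};  // mismatch: drop the cache and start over
        } else {
            return input.substr(st.lastPrompt.size());  // hit: prefill only the tail
        }
    }
    return input;
}

int main() {
    CacheState st;
    st.lastPrompt = "User: hi\nBot: hello\n";
    st.lastPromptTokens = 12;
    st.hasCache = true;

    // Turn 2 extends turn 1: only the new user message needs prefill.
    std::cout << ReusableSuffix(st, "User: hi\nBot: hello\nUser: how are you?\n")
              << "cache kept: " << st.hasCache << "\n";

    // Unrelated prompt: the cache is discarded, the whole input is prefilled.
    std::cout << ReusableSuffix(st, "User: something else\n")
              << "cache kept: " << st.hasCache << "\n";
}

After generation, the patch extends the record with lastPrompt += (input +
retString) and lastPromptTokens = promptLen + index, so a follow-up turn that
appends to the same conversation passes the prefix test and skips re-prefilling
the whole history.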