
Commit

Restore the previous chat mode (reduce copying)
黄宇扬 committed Jun 4, 2024
1 parent 5adbec0 commit 8abf307
Showing 2 changed files with 29 additions and 37 deletions.
4 changes: 4 additions & 0 deletions include/models/basellm.h
@@ -213,6 +213,10 @@ namespace fastllm {
         PastKVCacheManager pastKVCacheManager;
         bool saveHistoryChat = false;
 
+        std::string lastPrompt = "";
+        std::vector<std::pair<Data, Data> > *lastKeyValues = nullptr;
+        int lastPromptTokens = 0;
+
         DataType dataType = DataType::FLOAT32;
     };
 }
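Note: these three new members replace the PastKVCacheManager lookup path for the common single-session case. A minimal standalone sketch of the state they encode (the struct name and the stub Data type are ours, not fastllm's):

#include <string>
#include <utility>
#include <vector>

struct Data {};  // stand-in for fastllm::Data (one layer's key or value tensor)

struct ChatCacheState {
    std::string lastPrompt;                                        // full text already represented in the cache
    std::vector<std::pair<Data, Data> > *lastKeyValues = nullptr;  // per-block (K, V) tensors, owned
    int lastPromptTokens = 0;                                      // number of tokens already cached

    // Mirrors the reset pattern Response() uses when the cache cannot be reused.
    void Reset() {
        lastPrompt.clear();
        lastPromptTokens = 0;
        delete lastKeyValues;
        lastKeyValues = nullptr;
    }
};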
62 changes: 25 additions & 37 deletions src/models/basellm.cpp
@@ -136,16 +136,22 @@ namespace fastllm {
     std::string basellm::Response(const std::string &oriInput, RuntimeResult retCb,
                                   const fastllm::GenerationConfig &generationConfig) {
         std::string input = oriInput;
-        PastKVCacheMemory *memory;
-        std::string oldPrompt;
-        int oldTokens = 0;
         if (this->saveHistoryChat) {
-            memory = pastKVCacheManager.Get(input);
-            if (memory != nullptr) {
-                oldPrompt = memory->prompt;
-                oldTokens = memory->tokens;
-                input = input.substr(memory->prompt.size());
+            if (lastKeyValues != nullptr) {
+                if (input.size() < lastPrompt.size() || (input.substr(0, lastPrompt.size()) != lastPrompt)) {
+                    lastPrompt = "";
+                    lastPromptTokens = 0;
+                    delete lastKeyValues;
+                    lastKeyValues = nullptr;
+                } else {
+                    input = input.substr(lastPrompt.size());
+                }
             }
-        }
+        } else {
+            lastPrompt = "";
+            lastPromptTokens = 0;
+            delete lastKeyValues;
+            lastKeyValues = nullptr;
+        }
 
         //printf("lastPrompt = %s\n", lastPrompt.c_str());
@@ -168,32 +174,21 @@ namespace fastllm {
             for (int i = 0; i < inputTokenData.Count(0); i++) {
                 inputTokens[0].push_back(((float *) inputTokenData.cpuData)[i]);
             }
 
-        std::vector <std::pair <Data, Data> > pastKeyValues;
-        for (int i = 0; i < block_cnt; i++) {
-            pastKeyValues.push_back(std::make_pair(Data(this->dataType),
-                                                   Data(this->dataType)));
-        }
-
-        if (this->saveHistoryChat) {
-            if (memory != nullptr) {
-                for (int i = 0; i < block_cnt; i++) {
-                    pastKeyValues[i].first.CopyFrom(memory->kv[i].first);
-                    pastKeyValues[i].second.CopyFrom(memory->kv[i].second);
-                }
-            }
-            pastKVCacheManager.Unlock();
-        }
-
-        for (int i = 0; i < block_cnt; i++) {
-            pastKeyValues.back().first.SetKVCache();
-            pastKeyValues.back().second.SetKVCache();
-        }
+        if (lastKeyValues == nullptr) {
+            lastKeyValues = new std::vector<std::pair<Data, Data> >();
+            for (int i = 0; i < block_cnt; i++) {
+                lastKeyValues->push_back(std::make_pair(Data(this->dataType), Data(this->dataType)));
+                lastKeyValues->back().first.SetKVCache();
+                lastKeyValues->back().second.SetKVCache();
+            }
+        }
 
+        std::vector<std::pair<Data, Data> > &pastKeyValues = (*lastKeyValues);
         std::string retString = "";
         std::vector<float> results;
         LastTokensManager tokens(1, generationConfig.last_n);
-        int promptLen = oldTokens + inputTokens[0].size(), index = 0;
+        int promptLen = lastPromptTokens + inputTokens[0].size(), index = 0;
         int add_special_tokens = generationConfig.add_special_tokens? 1: 0;
         FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}, {"add_special_tokens", add_special_tokens}},
                       inputIds, attentionMask, positionIds);
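This hunk is where the copy reduction in the commit title comes from: the old path copied every cached (K, V) tensor into a fresh local vector on each call, while the new path allocates the vector once and binds a reference to it, so the decode loop below runs unchanged and the tensors grow in place across turns. A hedged sketch of the lazy-init-plus-alias pattern (stub types, not fastllm's):

#include <utility>
#include <vector>

struct Data {};                   // stand-in for fastllm::Data
static const int block_cnt = 32;  // assumed number of transformer blocks

std::vector<std::pair<Data, Data> > *lastKeyValues = nullptr;

void EnsureKVCache() {
    if (lastKeyValues == nullptr) {  // first turn, or right after a reset
        lastKeyValues = new std::vector<std::pair<Data, Data> >();
        for (int i = 0; i < block_cnt; i++) {
            lastKeyValues->push_back(std::make_pair(Data(), Data()));
        }
    }
}

int main() {
    EnsureKVCache();
    // Alias, not a copy: later turns append to the same tensors in place.
    std::vector<std::pair<Data, Data> > &pastKeyValues = *lastKeyValues;
    (void) pastKeyValues;
}

A side effect worth noting: the new loop calls SetKVCache() on each pair right after it is pushed, whereas the removed loop invoked it on pastKeyValues.back() every iteration and so only ever touched the last block.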
@@ -253,15 +248,8 @@ namespace fastllm {
             retCb(-1, retString.c_str());
 #endif
 
-        if (this->saveHistoryChat) {
-            std::string currentPrompt;
-            int currentTokens;
-            if (oldPrompt != "") {
-                pastKVCacheManager.Remove(oldPrompt);
-            }
-            pastKVCacheManager.Record(oriInput + retString, promptLen + index, &pastKeyValues);
-        }
-
+        lastPrompt += (input + retString);
+        lastPromptTokens = promptLen + index;
         return retString;
     }

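Since input was already stripped of the matched prefix, lastPrompt += (input + retString) keeps lastPrompt equal to the full conversation text, ready for the next turn's prefix check. End to end, the commit enables this incremental multi-turn flow. A hedged usage sketch: Response() and the history-chat flag appear in this diff, but CreateLLMModelFromFile, SetSaveHistoryChat, MakeInput, and MakeHistory are assumptions about fastllm's surrounding API and may differ in signature:

#include <string>
#include "model.h"  // assumed fastllm entry header

int main() {
    auto model = fastllm::CreateLLMModelFromFile("model.flm");  // assumed factory
    model->SetSaveHistoryChat(true);  // assumed setter: keep lastKeyValues alive across turns

    fastllm::GenerationConfig config;
    std::string history = "";
    std::string q1 = "What is fastllm?";
    std::string r1 = model->Response(model->MakeInput(history, 0, q1), nullptr, config);
    history = model->MakeHistory(history, 0, q1, r1);

    // The second prompt extends the first verbatim, so Response() strips the
    // lastPrompt prefix and prefills only the newly appended text.
    std::string q2 = "And how does the KV cache help?";
    model->Response(model->MakeInput(history, 1, q2), nullptr, config);
    return 0;
}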
