Add prompt cache feature
黄宇扬 committed Jul 13, 2024
1 parent 7ec1d78 commit 368f8c7
Showing 6 changed files with 126 additions and 58 deletions.
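In short: prompts are now cached at the token level. A prompt's KV cache can be recorded once (`AddPromptCache` on the C++ side, `add_cache` in Python), and any later request whose token sequence begins with a cached prefix reuses the stored KV instead of re-running prefill for it. A minimal usage sketch, assuming a model carrying a Hugging Face tokenizer (paths and prompts are placeholders; without `hf_tokenizer`, `add_cache` refuses to run, as the llm.py change below shows):

```python
from fastllm_pytools import llm

model = llm.model("path/to/model")  # placeholder path; must attach an HF tokenizer

# Record the KV cache for a long shared prefix (e.g. a system prompt) once.
system_prompt = "You are a helpful assistant."
model.add_cache(system_prompt)

# Requests whose token sequence starts with the cached tokens skip prefill
# for that prefix; only the new suffix is prefilled.
for chunk in model.stream_response(system_prompt + " Summarize this diff."):
    print(chunk, end="", flush=True)
```

Note that `stream_response` still runs the query through the model's prompt template (`MakeInput`), so whether the cached tokens really form a prefix of the final input depends on that template; caching the fully templated prompt is the safe choice.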
16 changes: 10 additions & 6 deletions include/models/basellm.h
@@ -35,6 +35,8 @@ namespace fastllm {
int curTokens = 0;
std::map <std::string, int> intParams;

int cacheLen = 0;

void Init(int blocks, DataType dataType);
};

@@ -50,34 +52,34 @@
};

struct PastKVCacheMemory {
std::string prompt;
std::vector <int> inputToken;
int tokens;
int recordTimes = 0;
long long flushTime;
std::vector<std::pair<Data, Data> > kv;

PastKVCacheMemory () {}

PastKVCacheMemory (const std::string &prompt, int tokens, long long flushTime, std::vector<std::pair<Data, Data> > *kv);
PastKVCacheMemory (const std::vector <int> &prompt, int tokens, long long flushTime, std::vector<std::pair<Data, Data> > *kv);
};

struct PastKVCacheManager {
std::mutex locker;
int maxRecordNum = 5;
long long flushTime = 0;
std::map <std::string, PastKVCacheMemory*> memorys;
std::map <std::vector <int>, PastKVCacheMemory*> memorys;

// Set the maximum number of records to keep
void SetMaxRecordNum(int maxRecordNum);

// Insert a record; if it already exists, increase its reference count
void Record(const std::string &prompt, int tokens, std::vector<std::pair<Data, Data> > *kv);
void Record(const std::vector <int> &inputToken, int tokens, std::vector<std::pair<Data, Data> > *kv);

// Try to remove a record; it is not actually deleted while the reference count is non-zero
void Remove(std::string prompt);
void Remove(const std::vector <int> &inputToken);

// Get the longest-matching memory and take the lock
PastKVCacheMemory *Get(const std::string &prompt);
PastKVCacheMemory *Get(const std::vector <int> &inputToken);

// Unlock
void Unlock();
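
Taken together, the manager is a small reference-counted LRU store keyed by token sequences. A behavioral sketch in Python (illustrative only; the real class guards every call with the mutex and stores fastllm `Data` tensors as the payload):

```python
class PromptCacheSketch:
    """Models PastKVCacheManager: refcounts, LRU eviction, longest-prefix lookup."""

    def __init__(self, max_records=5):
        self.max_records = max_records
        self.clock = 0                      # stands in for flushTime
        self.entries = {}                   # tuple(token ids) -> [refcount, last_use, kv]

    def record(self, tokens, kv):
        key = tuple(tokens)
        self.clock += 1
        if key in self.entries:             # already cached: bump the refcount
            self.entries[key][0] += 1
            self.entries[key][1] = self.clock
            return
        if len(self.entries) >= self.max_records:
            lru = min(self.entries, key=lambda k: self.entries[k][1])
            del self.entries[lru]           # evict the least recently used entry
        self.entries[key] = [1, self.clock, kv]

    def remove(self, tokens):
        key = tuple(tokens)
        if key in self.entries:
            self.entries[key][0] -= 1
            if self.entries[key][0] <= 0:   # freed only when nothing references it
                del self.entries[key]

    def get(self, tokens):
        best = None                         # longest cached key prefixing `tokens`
        for key in self.entries:
            if len(key) <= len(tokens) and tuple(tokens[:len(key)]) == key:
                if best is None or len(key) > len(best):
                    best = key
        if best is None:
            return None
        self.clock += 1
        self.entries[best][1] = self.clock
        return self.entries[best][2]
```

For example, with cached keys `(1, 2, 3)` and `(1, 2, 3, 4, 5)`, `get([1, 2, 3, 4, 9])` returns the `(1, 2, 3)` entry: the longer key fails the element-wise comparison at index 4, exactly as in `PastKVCacheManager::Get` below.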
@@ -173,6 +175,8 @@ namespace fastllm {

virtual void WarmUp() {}; // warm up

virtual void AddPromptCache(const std::vector <int> &inputTokens);

virtual std::string MakeInput(const std::string &history, int round, const std::string &input) = 0; // build the prompt from the history and the current input

virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output) = 0; // update the history with the current reply
23 changes: 8 additions & 15 deletions src/devices/cuda/fastllm-cuda.cu
@@ -3332,7 +3332,7 @@ bool FastllmCudaAttention(const fastllm::Data &q, const fastllm::Data &k, const
int batch = (mask.dims.size() == 3) ? mask.dims[0] : 1;
int maskStride = (mask.dims.size() == 3 ? mask.strides[0] : mask.Count(0));

if (q1 >= 1024 || (q1 > 1 && q1 != k1)) {
if (q1 >= 1024 || (q1 > 1 && q1 != k1 && k1 >= 1024)) {
float *qk = (float *) FastllmCudaMalloc(q1 * k1 * sizeof(float));
float beta = 0, one = 1;
auto fastllmCublasHandle = getFastllmCublasHandle();
@@ -3463,7 +3463,7 @@ bool FastllmCudaHalfAttention(const fastllm::Data &q, const fastllm::Data &k, co
int maskStride = (mask.dims.size() == 3 ? mask.strides[0] : mask.Count(0));

half beta = __float2half_rn(0.0f), one = __float2half_rn(1.0f), hscale = __float2half_rn(scale);
if (q1 >= 1024 || (q1 > 1 && q1 != k1)) {
if (q1 >= 1024 || (q1 > 1 && q1 != k1 && k1 >= 1024)) {
int alignQ1 = q1, alignK1 = k1;
bool useFastAttn = getCudaInfos()->hasTensorCore && batch == 1 && (q2 == 128 && v2 == 128);
if (useFastAttn) {
@@ -3824,10 +3824,10 @@ bool DoFastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::
qk[b] = mem + memSum;
memSum += s;
}


uint8_t ** pointers = (uint8_t**)FastllmCudaMalloc(sizeof(uint8_t*) * batch * k0 * 8);
uint8_t ** cpuPointers = new uint8_t*[batch * k0 * 8];
if (true) {
uint8_t ** pointers = (uint8_t**)FastllmCudaMalloc(sizeof(uint8_t*) * batch * k0 * 8);
uint8_t ** cpuPointers = new uint8_t*[batch * k0 * 8];
for (int b = 0; b < batch; b++) {
for (int i = 0; i < k0; i++) {
cpuPointers[(b * k0 + i) * 8 + 0] = (uint8_t *) q[b]->cudaData + i * group * q[b]->dims[1] * q[b]->dims[2] * sizeof(T);
@@ -3846,14 +3846,10 @@ bool DoFastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::
} else {
FastllmMatMulTransBBatchKernel <128> <<<batch * k0, 128>>> (pointers, scale);
}
FastllmCudaFree(pointers);
delete[] cpuPointers;
}

if (true) {
int outer = q[0]->dims[0] * q[0]->dims[1];
uint8_t ** pointers = (uint8_t**)FastllmCudaMalloc(sizeof(uint8_t*) * batch * 2);
uint8_t ** cpuPointers = new uint8_t*[batch * 2];
int maxChannels = 0;
for (int b = 0; b < batch; b++) {
int outer = q[b]->dims[0] * q[b]->dims[1];
@@ -3870,13 +3866,9 @@ bool DoFastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::
} else {
FastllmSoftmaxKernelBatchInner1 <T, 128> <<<batch * outer, 128>>> (pointers, outer);
}
FastllmCudaFree(pointers);
delete[] cpuPointers;
}

if (true) {
uint8_t ** pointers = (uint8_t**)FastllmCudaMalloc(sizeof(uint8_t*) * batch * k0 * 8);
uint8_t ** cpuPointers = new uint8_t*[batch * k0 * 8];
for (int b = 0; b < batch; b++) {
for (int i = 0; i < k0; i++) {
cpuPointers[(b * k0 + i) * 8 + 0] = (uint8_t *) qk[b] + i * group * q[b]->dims[1] * k[b]->dims[1] * sizeof(T);
@@ -3896,10 +3888,11 @@ bool DoFastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::
} else {
FastllmMatMulKernel <128> <<<batch * k0, 128>>> (pointers, 1.0f);
}
FastllmCudaFree(pointers);
delete[] cpuPointers;
}

FastllmCudaFree(pointers);
delete[] cpuPointers;

FastllmCudaFree(mem);
delete[] qk;

107 changes: 76 additions & 31 deletions src/models/basellm.cpp
@@ -56,8 +56,8 @@ namespace fastllm {
preTokens = 0;
}

PastKVCacheMemory::PastKVCacheMemory(const std::string &prompt, int tokens, long long flushTime, std::vector<std::pair<Data, Data> > *kv) {
this->prompt = prompt;
PastKVCacheMemory::PastKVCacheMemory(const std::vector <int> &inputToken, int tokens, long long flushTime, std::vector<std::pair<Data, Data> > *kv) {
this->inputToken = inputToken;
this->tokens = tokens;
this->flushTime = flushTime;
this->recordTimes = 1;
@@ -66,9 +66,6 @@ namespace fastllm {
this->kv.push_back(std::make_pair(Data(dataType), Data(dataType)));
}
for (int i = 0; i < kv->size(); i++) {
(*kv)[i].first.ToDevice(DataDevice::CPU);
(*kv)[i].second.ToDevice(DataDevice::CPU);

this->kv[i].first.CopyFrom((*kv)[i].first);
this->kv[i].second.CopyFrom((*kv)[i].second);
}
@@ -79,54 +76,63 @@ namespace fastllm {
this->maxRecordNum = maxRecordNum;
}

void PastKVCacheManager::Record(const std::string &prompt, int tokens, std::vector<std::pair<Data, Data> > *kv) {
void PastKVCacheManager::Record(const std::vector <int> &inputToken, int tokens, std::vector<std::pair<Data, Data> > *kv) {
std::lock_guard <std::mutex> lock(this->locker);
if (this->memorys.find(prompt) != this->memorys.end()) {
this->memorys[prompt]->recordTimes++;
this->memorys[prompt]->flushTime = ++flushTime;
if (this->memorys.find(inputToken) != this->memorys.end()) {
this->memorys[inputToken]->recordTimes++;
this->memorys[inputToken]->flushTime = ++flushTime;
return;
}

if (this->memorys.size() >= this->maxRecordNum) {
std::string prompt = "";
std::vector <int> eraseToken;
long long minFlushTime = (1LL << 60);
for (auto &it : this->memorys) {
if (it.second->flushTime < minFlushTime) {
minFlushTime = it.second->flushTime;
prompt = it.first;
eraseToken = it.first;
}
}
delete this->memorys[prompt];
this->memorys.erase(this->memorys.find(prompt));
delete this->memorys[eraseToken];
this->memorys.erase(this->memorys.find(eraseToken));
}

this->memorys[prompt] = new PastKVCacheMemory(prompt, tokens, ++flushTime, kv);
this->memorys[inputToken] = new PastKVCacheMemory(inputToken, tokens, ++flushTime, kv);
}

void PastKVCacheManager::Remove(std::string prompt) {
void PastKVCacheManager::Remove(const std::vector <int> &inputToken) {
std::lock_guard <std::mutex> lock(this->locker);
if (this->memorys.find(prompt) != this->memorys.end()) {
if ((--this->memorys[prompt]->recordTimes) <= 0) {
delete this->memorys[prompt];
this->memorys.erase(this->memorys.find(prompt));
if (this->memorys.find(inputToken) != this->memorys.end()) {
if ((--this->memorys[inputToken]->recordTimes) <= 0) {
delete this->memorys[inputToken];
this->memorys.erase(this->memorys.find(inputToken));
}
}
}

PastKVCacheMemory *PastKVCacheManager::Get(const std::string &prompt) {
locker.lock();
std::string maxPrompt = "";
PastKVCacheMemory *PastKVCacheManager::Get(const std::vector <int> &inputToken) {
std::lock_guard <std::mutex> lock(this->locker);
std::vector <int> maxToken;
for (auto &it : this->memorys) {
const std::string &cur = it.first;
if (cur.size() > maxPrompt.size() && cur.size() <= prompt.size() && prompt.substr(0, cur.size()) == cur) {
maxPrompt = cur;
const std::vector <int> &cur = it.first;
if (cur.size() > maxToken.size() && cur.size() <= inputToken.size()) {
bool match = true;
for (int i = 0; i < cur.size(); i++) {
if (inputToken[i] != cur[i]) {
match = false;
break;
}
}
if (match) {
maxToken = cur;
}
}
}
if (maxPrompt == "") {
if (maxToken.size() == 0) {
return nullptr;
}
this->memorys[maxPrompt]->flushTime = ++this->flushTime;
return this->memorys[maxPrompt];
this->memorys[maxToken]->flushTime = ++this->flushTime;
return this->memorys[maxToken];
}

void PastKVCacheManager::Unlock() {
@@ -542,8 +548,8 @@ namespace fastllm {
handles.push_back(it.first);

if (it.second->preTokens == 0) {
it.second->intParams["add_special_tokens"] = it.second->generationConfig.add_special_tokens;
it.second->intParams["promptLen"] = it.second->currentTokens.size();
it.second->intParams["add_special_tokens"] = it.second->cacheLen > 0 ? false : it.second->generationConfig.add_special_tokens;
it.second->intParams["promptLen"] = it.second->cacheLen + it.second->currentTokens.size();
it.second->intParams["index"] = 0;
} else {
it.second->intParams["index"]++;
@@ -579,7 +585,7 @@ namespace fastllm {
pastKeyValues.push_back(std::make_pair(&it.second->pastKeyValues[i].first,
&it.second->pastKeyValues[i].second));
}
if (isPrompt) {
if (isPrompt) {
cnt += it.second->currentTokens.size();

if (cnt > 1024) {
@@ -681,6 +687,17 @@ printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)to
context->currentTokens = inputTokens;
context->generationConfig = generationConfig;
context->tokens = LastTokensUnit(generationConfig.last_n);

auto cache = pastKVCacheManager.Get(inputTokens);
if (cache != nullptr) {
for (int i = 0; i < this->block_cnt; i++) {
context->pastKeyValues[i].first.CopyFrom(cache->kv[i].first);
context->pastKeyValues[i].second.CopyFrom(cache->kv[i].second);
}
context->currentTokens.erase(context->currentTokens.begin(), context->currentTokens.begin() + cache->inputToken.size());
context->cacheLen = cache->inputToken.size();
}

dictLocker.unlock();
dictCV.notify_one();
return handleId;
@@ -749,6 +766,34 @@ printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)to
}
}

void basellm::AddPromptCache(const std::vector <int> &inputTokens) {
std::unique_lock<std::mutex> dictLocker(this->dictLocker);
auto cache = pastKVCacheManager.Get(inputTokens);
if (cache != nullptr && cache->inputToken.size() == inputTokens.size()) {
return;
}
Data inputIds, attentionMask, positionIds;
std::vector<std::pair<Data, Data> > pastKeyValues;
for (int i = 0; i < block_cnt; i++) {
pastKeyValues.push_back(std::make_pair(Data(this->dataType), Data(this->dataType)));
pastKeyValues.back().first.SetKVCache();
pastKeyValues.back().second.SetKVCache();
}

int promptLen = inputTokens.size(), index = 0;
int add_special_tokens = false;
std::vector <std::vector <float> > fInputTokens;
fInputTokens.resize(1);
for (int i = 0; i < inputTokens.size(); i++) {
fInputTokens[0].push_back(inputTokens[i]);
}
FillLLMInputs(fInputTokens, {{"promptLen", promptLen}, {"index", index}, {"add_special_tokens", add_special_tokens}},
inputIds, attentionMask, positionIds);
ToDataType(attentionMask, this->dataType);
int ret = Forward(inputIds, attentionMask, positionIds, pastKeyValues);
pastKVCacheManager.Record(inputTokens, inputTokens.size(), &pastKeyValues);
}

bool basellm::NeedAttentionMask(int qlen, int klen) {
return true;
}
17 changes: 11 additions & 6 deletions src/models/chatglm.cpp
@@ -712,6 +712,7 @@ namespace fastllm {
(ids[0] != this->gmask_token_id || ids[1] != this->bos_token_id))) {
ids.insert(ids.begin(), this->bos_token_id);
ids.insert(ids.begin(), this->gmask_token_id);
promptLen += 2;
}
}
}
@@ -733,18 +734,22 @@
}

if (seqLen <= 1024) {
std::vector<float> vmask = std::vector<float>(seqLen * seqLen, 0);
for (int i = 0; i < seqLen - 1; i++) {
vmask[i * seqLen + seqLen - 1] = 1;
}
if (GetVersion() == 2) {
std::vector <float> vmask = std::vector <float> (seqLen * promptLen, 0);
for (int i = 0; i < seqLen; i++) {
vpids[i] = promptLen - seqLen + i;
for (int j = i + 1; j < seqLen; j++) {
vmask[i * seqLen + j] = 1;
vmask[i * promptLen + (promptLen - seqLen + j)] = 1;
}
}
attentionMask.CopyFrom(Data(DataType::FLOAT32, {seqLen, promptLen}, vmask));
} else {
std::vector<float> vmask = std::vector<float>(seqLen * seqLen, 0);
for (int i = 0; i < seqLen - 1; i++) {
vmask[i * seqLen + seqLen - 1] = 1;
}
attentionMask.CopyFrom(Data(DataType::FLOAT32, {seqLen, seqLen}, vmask));
}
attentionMask.CopyFrom(Data(DataType::FLOAT32, {seqLen, seqLen}, vmask));
} else {
attentionMask = Data();
}
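
With a cached prefix, a ChatGLM2 prefill no longer starts at position 0, so the mask becomes `seqLen x promptLen`: the `seqLen` new tokens may attend to every earlier position, causality is enforced only among themselves, and position ids start at `promptLen - seqLen`. A quick illustration in Python, using the same convention as the C++ above (1 = masked):

```python
import numpy as np

prompt_len, seq_len = 5, 2                    # 3 cached tokens + 2 new tokens
pids = [prompt_len - seq_len + i for i in range(seq_len)]
mask = np.zeros((seq_len, prompt_len))
for i in range(seq_len):
    for j in range(i + 1, seq_len):
        mask[i, prompt_len - seq_len + j] = 1 # hide not-yet-seen new tokens

print(pids)  # [3, 4]
print(mask)  # [[0. 0. 0. 0. 1.]  <- the token at position 3 cannot see position 4
             #  [0. 0. 0. 0. 0.]] <- the token at position 4 sees everything
```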
12 changes: 12 additions & 0 deletions tools/fastllm_pytools/llm.py
@@ -36,6 +36,8 @@
ctypes.c_int, ctypes.POINTER(ctypes.c_int)]
fastllm_lib.launch_response_llm_model.restype = ctypes.c_int

fastllm_lib.add_cache_llm_model.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_void_p]

fastllm_lib.fetch_response_llm_model.argtypes = [ctypes.c_int, ctypes.c_int]
fastllm_lib.fetch_response_llm_model.restype = ctypes.c_int

@@ -655,6 +657,16 @@ def stream_response(self,
res += cur;
yield res;

def add_cache(self,
prompt: str):
if (self.hf_tokenizer != None):
tokenizer = self.hf_tokenizer
input = tokenizer.encode(prompt);
fastllm_lib.add_cache_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input));
else:
print("add_cache failed: need hf_tokenizer.")
exit(0)
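
A rough way to verify the effect is to compare time-to-first-token before and after warming the cache. A sketch under the same assumptions as above (placeholder model path; the chat template must keep the cached tokens as a literal prefix for a hit to occur):

```python
import time
from fastllm_pytools import llm

model = llm.model("path/to/model")            # placeholder; needs hf_tokenizer

def first_token_latency(model, prompt):
    start = time.time()
    for _ in model.stream_response(prompt):   # stop at the first chunk
        break
    return time.time() - start

long_prefix = "..."                           # a long system prompt (placeholder)
cold = first_token_latency(model, long_prefix + " Question 1")
model.add_cache(long_prefix)                  # one full prefill, KV recorded
warm = first_token_latency(model, long_prefix + " Question 2")
print(f"cold: {cold:.2f}s  warm: {warm:.2f}s")
```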

async def stream_response_async(self,
query: str,
history: List[Tuple[str, str]] = None,
9 changes: 9 additions & 0 deletions tools/src/pytools.cpp
@@ -382,4 +382,13 @@ extern "C" {
}
return ret;
}

DLL_EXPORT void add_cache_llm_model(int modelId, int len, int *values) {
std::vector <int> input;
for (int i = 0; i < len; i++) {
input.push_back(values[i]);
}
auto model = models.GetModel(modelId);
model->AddPromptCache(input);
}
};
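
For completeness, this is the raw C-API entry point that llm.py's `add_cache` wraps. A sketch of calling it directly through ctypes, assuming `fastllm_lib` is the loaded shared library and `model_id` is a handle obtained from the library's model-creation calls (the token ids are placeholders):

```python
import ctypes

tokens = [101, 102, 103]                      # placeholder token ids
arr = (ctypes.c_int * len(tokens))(*tokens)   # passed as the c_void_p argument
fastllm_lib.add_cache_llm_model(model_id, len(tokens), arr)
```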
