From 5c282eb5d6d0e3004092d0fa350b31c54338825e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?= Date: Wed, 14 Aug 2024 14:32:13 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E8=B0=83=E5=BA=A6=E7=AD=96?= =?UTF-8?q?=E7=95=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/models/basellm.cpp | 40 +++++++++++++++++++++++++++++++--------- 1 file changed, 31 insertions(+), 9 deletions(-) diff --git a/src/models/basellm.cpp b/src/models/basellm.cpp index 0468bf63..05816754 100644 --- a/src/models/basellm.cpp +++ b/src/models/basellm.cpp @@ -526,10 +526,10 @@ namespace fastllm { model->tokensLimit = maxTotalLens; int limit = maxTotalLens; - model->promptLimit = limit * 2 / 3; + model->promptLimit = limit * 3 / 4; if (model->verbose) { - printf("Fastllm KV Cache Limit: %f MB.\n", (double)kvCacheLimit / 1024 / 1024); + printf("Fastllm KV Cache Limit: %f MB.\n", (double)kvCacheLimit / 1e6); printf("Fastllm KV Cache Token limit: %d tokens.\n", maxTotalLens); printf("Fastllm Prompt Token limit: %d tokens.\n", std::min(model->max_positions, model->promptLimit)); printf("Fastllm Batch limit: %d.\n", maxBatch); @@ -567,10 +567,11 @@ namespace fastllm { int limit = maxTotalLens; int promptLimit = model->promptLimit; - int lenSum = 0; + int lenSum = 0, currentActivate = 0; for (auto &it: model->responseContextDict.dicts) { if (it.second->pastKeyValues[0].first.expansionDims.size() > 0) { lenSum += it.second->pastKeyValues[0].first.expansionDims[1]; + currentActivate++; } } @@ -579,10 +580,11 @@ namespace fastllm { if (isPrompt == 0 && seqLens.size() > 0) { continue; } +/* if (lenSum >= promptLimit && isPrompt) { continue; } - +*/ for (auto &it: model->responseContextDict.dicts) { if (it.second->isEnding) { continue; @@ -593,7 +595,8 @@ namespace fastllm { if (!isPrompt && it.second->preTokens == 0) { continue; } - if (it.second->currentTokens.size() > promptLimit) { + + if (it.second->currentTokens.size() 
> maxTotalLens) { it.second->isEnding = true; it.second->error = ResponseContextErrorPromptTooLong; continue; @@ -601,9 +604,14 @@ namespace fastllm { int outputLimit = it.second->generationConfig.output_token_limit; outputLimit = (outputLimit < 0 ? 128 : outputLimit); +/* if (isPrompt && lenSum + it.second->currentTokens.size() > promptLimit) { continue; } +*/ + if (isPrompt && lenSum + it.second->currentTokens.size() + (currentActivate + 1) * 256 > maxTotalLens) { + continue; + } if (!isPrompt) { if (it.second->pastKeyValues[0].first.expansionDims[1] == it.second->pastKeyValues[0].first.dims[1]) { @@ -617,6 +625,9 @@ namespace fastllm { } lenSum += predictLen; } + } else { + lenSum += it.second->currentTokens.size(); + currentActivate++; } generationConfigs.push_back(it.second->generationConfig); @@ -720,26 +731,37 @@ auto st = std::chrono::system_clock::now(); *pastKeyValue1, generationConfigs[0], tokensManager, logits[0])}; } } - - //PrintProfiler(); /*int total = 0; for (int i : seqLens) total += i; float spend = GetSpan(st, std::chrono::system_clock::now()); printf("len = %d, spend = %f s. 
tokens / s = %f\n", (int)total, spend, (float)total / spend); */ + dictLocker.lock(); + if (model->verbose) { genTokens += seqLens.size(); auto nowTime = std::chrono::system_clock::now(); float spend = GetSpan(lastRecordTime, nowTime); if (spend > 1) { - printf("Current batch: %d, Speed: %f tokens / s.\n", (int)seqLens.size(), (float)genTokens / spend); + int total = 0, alive = 0, aliveLen = 0, pending = 0; + for (auto &it: model->responseContextDict.dicts) { + if (it.second->isEnding) { + continue; + } + if (it.second->pastKeyValues[0].first.expansionDims.size() > 0) { + alive++; + aliveLen += it.second->pastKeyValues[0].first.expansionDims[1]; + } else { + pending++; + } + } + printf("alive = %d, pending = %d, contextLen = %d, Speed: %f tokens / s.\n", alive, pending, aliveLen, (float)genTokens / spend); lastRecordTime = nowTime; genTokens = 0; } } - dictLocker.lock(); for (int i = 0; i < handles.size(); i++) { auto &it = *model->responseContextDict.dicts.find(handles[i]); int curRet = ret[i];