
Update scheduling strategy
黄宇扬 committed Aug 14, 2024
1 parent 4d1b144 commit 5c282eb
Showing 1 changed file with 31 additions and 9 deletions.
src/models/basellm.cpp
@@ -526,10 +526,10 @@ namespace fastllm {

model->tokensLimit = maxTotalLens;
int limit = maxTotalLens;
- model->promptLimit = limit * 2 / 3;
+ model->promptLimit = limit * 3 / 4;

if (model->verbose) {
printf("Fastllm KV Cache Limit: %f MB.\n", (double)kvCacheLimit / 1024 / 1024);
printf("Fastllm KV Cache Limit: %f MB.\n", (double)kvCacheLimit / 1e6);
printf("Fastllm KV Cache Token limit: %d tokens.\n", maxTotalLens);
printf("Fastllm Prompt Token limit: %d tokens.\n", std::min(model->max_positions, model->promptLimit));
printf("Fastllm Batch limit: %d.\n", maxBatch);
@@ -567,10 +567,11 @@ namespace fastllm {
int limit = maxTotalLens;
int promptLimit = model->promptLimit;

- int lenSum = 0;
+ int lenSum = 0, currentActivate = 0;
for (auto &it: model->responseContextDict.dicts) {
if (it.second->pastKeyValues[0].first.expansionDims.size() > 0) {
lenSum += it.second->pastKeyValues[0].first.expansionDims[1];
+ currentActivate++;
}
}

@@ -579,10 +580,11 @@
if (isPrompt == 0 && seqLens.size() > 0) {
continue;
}
+ /*
if (lenSum >= promptLimit && isPrompt) {
continue;
}

+ */
for (auto &it: model->responseContextDict.dicts) {
if (it.second->isEnding) {
continue;
@@ -593,17 +595,23 @@
if (!isPrompt && it.second->preTokens == 0) {
continue;
}
- if (it.second->currentTokens.size() > promptLimit) {

+ if (it.second->currentTokens.size() > maxTotalLens) {
it.second->isEnding = true;
it.second->error = ResponseContextErrorPromptTooLong;
continue;
}

int outputLimit = it.second->generationConfig.output_token_limit;
outputLimit = (outputLimit < 0 ? 128 : outputLimit);
+ /*
if (isPrompt && lenSum + it.second->currentTokens.size() > promptLimit) {
continue;
}
+ */
+ if (isPrompt && lenSum + it.second->currentTokens.size() + (currentActivate + 1) * 256 > maxTotalLens) {
+ continue;
+ }

if (!isPrompt) {
if (it.second->pastKeyValues[0].first.expansionDims[1] == it.second->pastKeyValues[0].first.dims[1]) {
@@ -617,6 +625,9 @@
}
lenSum += predictLen;
}
+ } else {
+ lenSum += it.second->currentTokens.size();
+ currentActivate++;
}

generationConfigs.push_back(it.second->generationConfig);
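
Taken together, the changes in this hunk replace the hard promptLimit gate (now commented out) with a headroom-based admission rule: a prompt longer than the entire token budget is rejected outright with ResponseContextErrorPromptTooLong, and a waiting prompt is scheduled only if, after adding its tokens, every active request, including the new one, keeps roughly 256 KV-cache tokens of decode headroom; once admitted, its length is added to lenSum and currentActivate grows by one. A self-contained sketch of the admission check under those assumed variable roles (the function name is invented, not fastllm code):

    // lenSum: KV-cache tokens already committed to active requests.
    // currentActivate: number of requests currently holding cache.
    // Admit a new prompt only if ~256 decode tokens per request
    // (counting the candidate itself) still fit in the total budget.
    bool canAdmitPrompt(int lenSum, int currentActivate,
                        int promptTokens, int maxTotalLens) {
        return lenSum + promptTokens + (currentActivate + 1) * 256 <= maxTotalLens;
    }
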
@@ -720,26 +731,37 @@ auto st = std::chrono::system_clock::now();
*pastKeyValue1, generationConfigs[0], tokensManager, logits[0])};
}
}


//PrintProfiler();
/*int total = 0;
for (int i : seqLens) total += i;
float spend = GetSpan(st, std::chrono::system_clock::now());
printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)total / spend);
*/
+ dictLocker.lock();

if (model->verbose) {
genTokens += seqLens.size();
auto nowTime = std::chrono::system_clock::now();
float spend = GetSpan(lastRecordTime, nowTime);
if (spend > 1) {
printf("Current batch: %d, Speed: %f tokens / s.\n", (int)seqLens.size(), (float)genTokens / spend);
+ int alive = 0, aliveLen = 0, pending = 0;
+ for (auto &it: model->responseContextDict.dicts) {
+ if (it.second->isEnding) {
+ continue;
+ }
+ if (it.second->pastKeyValues[0].first.expansionDims.size() > 0) {
+ alive++;
+ aliveLen += it.second->pastKeyValues[0].first.expansionDims[1];
+ } else {
+ pending++;
+ }
+ }
printf("alive = %d, pending = %d, contextLen = %d, Speed: %f tokens / s.\n", alive, pending, aliveLen, (int)seqLens.size(), (float)genTokens / spend);
lastRecordTime = nowTime;
genTokens = 0;
}
}

- dictLocker.lock();
for (int i = 0; i < handles.size(); i++) {
auto &it = *model->responseContextDict.dicts.find(handles[i]);
int curRet = ret[i];
