
Commit

Add GPU memory control to the async service
黄宇扬 committed Jul 17, 2024
1 parent 8c92ca8 commit e18535c
Showing 9 changed files with 96 additions and 4 deletions.
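In short, the asynchronous serving loop now derives a KV-cache byte budget from free GPU memory (keeping 2 GB of headroom per device), converts it into a total token limit using the per-token KV-cache element count measured at warm-up, and caps the batch at min(512, tokenLimit / 128). A minimal standalone sketch of that arithmetic, with hypothetical numbers in place of the detected free memory and the measured per-token element count:

#include <algorithm>
#include <cstdio>

int main() {
    long long freeBytesPerGpu = 24LL << 30;           // hypothetical: one GPU with 24 GB free
    long long kvCacheLimit = std::max(0LL, freeBytesPerGpu - (2LL << 30)); // keep 2 GB headroom
    long long elementsInKVCachePerToken = 2LL * 32 * 32 * 128; // hypothetical: 32 layers, K + V of 32 x 128 each
    int unitSize = 2;                                  // 2 bytes per element for a float16 KV cache
    int maxTotalLens = (int)(kvCacheLimit / (elementsInKVCachePerToken * unitSize));
    int maxBatch = std::max(1, std::min(512, maxTotalLens / 128));
    std::printf("token limit = %d, batch limit = %d\n", maxTotalLens, maxBatch); // 45056, 352
    return 0;
}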
5 changes: 5 additions & 0 deletions include/models/basellm.h
@@ -240,6 +240,11 @@ namespace fastllm {
std::vector<std::pair<Data, Data> > *lastKeyValues = nullptr;
int lastPromptTokens = 0;

long long elementsInKVCachePerToken = -1; // number of KV cache elements used per token
long long kvCacheLimit = -1;
int maxBatch = -1;
bool verbose = false;

DataType dataType = DataType::FLOAT32;
bool isFree = false; // whether the model has been released
};
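The new fields, together with the existing tokensLimit, let callers override the automatically derived limits before the serving loop starts. A hedged usage sketch, assuming an already-loaded model whose members are publicly accessible as declared above:

#include "models/basellm.h"   // header path as laid out in this repository (include setup assumed)

// Sketch only: override the automatic limits on an existing fastllm::basellm instance.
void ConfigureServingLimits(fastllm::basellm *model) {
    model->kvCacheLimit = 8LL << 30;   // cap the KV cache at 8 GB instead of auto-detecting free memory
    model->maxBatch = 64;              // cap the number of sequences decoded per step
    model->verbose = true;             // print current batch size and tokens/s roughly once per second
    // Alternatively, model->tokensLimit can be set to cap the KV cache by token count directly.
}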
74 changes: 70 additions & 4 deletions src/models/basellm.cpp
@@ -486,6 +486,40 @@ namespace fastllm {
if (mainLoop == nullptr) {
if (mainLoop == nullptr) {
mainLoop = new std::thread([](basellm *model) {
long long kvCacheLimit = 16LL << 30;
#ifdef USE_CUDA
auto freeSizes = FastllmCudaGetFreeSizes();
kvCacheLimit = 0;
for (long long i : freeSizes) {
kvCacheLimit += std::max(0LL, i - (2LL << 30));
}
#endif
if (model->kvCacheLimit > 0) {
kvCacheLimit = model->kvCacheLimit;
}

int unitSize = (model->dataType == DataType::FLOAT32 ? 4 : 2);
int maxTotalLens = kvCacheLimit / (model->elementsInKVCachePerToken * unitSize);
if (model->tokensLimit > 0) {
maxTotalLens = model->tokensLimit;
}

int maxBatch = std::max(1, std::min(512, maxTotalLens / 128));
if (model->maxBatch > 0) {
maxBatch = model->maxBatch;
}

model->tokensLimit = maxTotalLens;

model->verbose = true;
if (model->verbose) {
printf("Fastllm KV Cache Limit: %f MB.\n", (double)kvCacheLimit / 1024 / 1024);
printf("Fastllm KV Cache Token limit: %d tokens.\n", maxTotalLens);
printf("Fastllm Batch limit: %d.\n", maxBatch);
}

auto lastRecordTime = std::chrono::system_clock::now();
long long genTokens = 0;
while (true) {
if (model->isFree) {
break;
@@ -502,10 +536,12 @@

std::unique_lock<std::mutex> dictLocker(model->dictLocker);

int limit = model->tokensLimit > 0 ? model->tokensLimit : 1e9;
int limit = maxTotalLens;
int promptLimit = limit * 2 / 3;

int lenSum = 0;
for (auto &it: model->responseContextDict.dicts) {
if (it.second->pastKeyValues[0].first.expansionDims.size() > 0 && !it.second->isEnding) {
if (it.second->pastKeyValues[0].first.expansionDims.size() > 0) {
lenSum += it.second->pastKeyValues[0].first.expansionDims[1];
}
}
@@ -515,7 +551,7 @@
if (isPrompt == 0 && seqLens.size() > 0) {
continue;
}
if (lenSum > limit && isPrompt) {
if (lenSum >= promptLimit && isPrompt) {
continue;
}

@@ -532,10 +568,24 @@

int outputLimit = it.second->generationConfig.output_token_limit;
outputLimit = (outputLimit < 0 ? 128 : outputLimit);
if (isPrompt && lenSum + it.second->currentTokens.size() + outputLimit > limit) {
if (isPrompt && lenSum + it.second->currentTokens.size() > promptLimit) {
continue;
}

if (!isPrompt) {
if (it.second->pastKeyValues[0].first.expansionDims[1] == it.second->pastKeyValues[0].first.dims[1]) {
int sur = it.second->generationConfig.output_token_limit - it.second->curTokens;
int predictLen = 256;
if (sur > 0) {
predictLen = std::min(predictLen, ((sur - 1) / 128 + 1) * 128);
}
if (lenSum + predictLen > limit) {
continue;
}
lenSum += predictLen;
}
}

generationConfigs.push_back(it.second->generationConfig);
if (it.second->generationConfig.output_logits) {
it.second->resultLogits.push(new std::vector<float>());
@@ -593,6 +643,10 @@ namespace fastllm {
}
// break;
}

if (seqLens.size() >= maxBatch || lenSum + seqLens.size() * 128 > limit) {
break;
}
}
}
if (seqLens.size() > 0) {
@@ -633,13 +687,25 @@ auto st = std::chrono::system_clock::now();
*pastKeyValue1, generationConfigs[0], tokensManager, logits[0])};
}
}


//PrintProfiler();
/*int total = 0;
for (int i : seqLens) total += i;
float spend = GetSpan(st, std::chrono::system_clock::now());
printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)total / spend);
*/
if (model->verbose) {
genTokens += seqLens.size();
auto nowTime = std::chrono::system_clock::now();
float spend = GetSpan(lastRecordTime, nowTime);
if (spend > 1) {
printf("Current batch: %d, Speed: %f tokens / s.\n", (int)seqLens.size(), (float)genTokens / spend);
lastRecordTime = nowTime;
genTokens = 0;
}
}

dictLocker.lock();
for (int i = 0; i < handles.size(); i++) {
auto &it = *model->responseContextDict.dicts.find(handles[i]);
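Within the scheduling loop above, prompt-phase requests are admitted only while the running KV total stays below two thirds of the token limit, and a decode-phase sequence whose KV buffer is about to expand reserves up to 256 further tokens, rounded up to a multiple of 128 from its remaining output budget. A small standalone restatement of that arithmetic with hypothetical numbers:

#include <algorithm>
#include <cstdio>

int main() {
    int limit = 45056;                // total token limit from the budget calculation (hypothetical)
    int promptLimit = limit * 2 / 3;  // prompt-phase requests stop being admitted past this point

    int outputTokenLimit = 512;       // hypothetical per-request generation limit
    int curTokens = 450;              // tokens already generated for this request (hypothetical)
    int sur = outputTokenLimit - curTokens;   // remaining output budget
    int predictLen = 256;
    if (sur > 0) {
        predictLen = std::min(predictLen, ((sur - 1) / 128 + 1) * 128);  // round the remainder up to 128
    }
    std::printf("promptLimit = %d tokens, reserve for this sequence = %d tokens\n",
                promptLimit, predictLen);      // 30037, 128
    return 0;
}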
3 changes: 3 additions & 0 deletions src/models/chatglm.cpp
@@ -947,6 +947,9 @@ namespace fastllm {
Data(DataType::FLOAT32)));
}
Forward(inputIds, attentionMask, positionIds, pastKeyValues);
elementsInKVCachePerToken = (long long)block_cnt *
(pastKeyValues[0].first.dims[0] * pastKeyValues[0].first.dims[2] +
pastKeyValues[0].second.dims[0] * pastKeyValues[0].second.dims[2]);
printf("finish.\n");
}

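The same three-line addition appears in each model's warm-up (chatglm, deepseekv2, graphllm, llama, moe, moss, qwen): after a single-token forward pass, the first layer's key and value shapes are read and scaled by the layer count block_cnt to record elementsInKVCachePerToken. A minimal worked example with hypothetical head count and head dimension:

#include <cstdio>

int main() {
    // Hypothetical dimensions: 32 layers; per layer the first token's K and V each
    // have dims[0] = 8 (KV heads) and dims[2] = 128 (head dimension).
    long long block_cnt = 32;
    long long kDims0 = 8, kDims2 = 128;   // pastKeyValues[0].first.dims[0], .dims[2]
    long long vDims0 = 8, vDims2 = 128;   // pastKeyValues[0].second.dims[0], .dims[2]
    long long elementsInKVCachePerToken = block_cnt * (kDims0 * kDims2 + vDims0 * vDims2);
    std::printf("%lld KV cache elements per token\n", elementsInKVCachePerToken);  // 65536
    return 0;
}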
3 changes: 3 additions & 0 deletions src/models/deepseekv2.cpp
@@ -909,6 +909,9 @@ namespace fastllm {
Data(DataType::FLOAT32)));
}
Forward(inputIds, attentionMask, positionIds, pastKeyValues);
elementsInKVCachePerToken = (long long)block_cnt *
(pastKeyValues[0].first.dims[0] * pastKeyValues[0].first.dims[2] +
pastKeyValues[0].second.dims[0] * pastKeyValues[0].second.dims[2]);
printf("finish.\n");
this->num_experts_per_tok = oldTopk;
}
3 changes: 3 additions & 0 deletions src/models/graphllm.cpp
@@ -321,6 +321,9 @@ namespace fastllm {
Data(DataType::FLOAT32)));
}
Forward(inputIds, attentionMask, positionIds, pastKeyValues);
elementsInKVCachePerToken = (long long)block_cnt *
(pastKeyValues[0].first.dims[0] * pastKeyValues[0].first.dims[2] +
pastKeyValues[0].second.dims[0] * pastKeyValues[0].second.dims[2]);
printf("finish.\n");
}

3 changes: 3 additions & 0 deletions src/models/llama.cpp
@@ -1009,6 +1009,9 @@ namespace fastllm {
this->weight["lm_head.weight"].CopyFrom(this->weight["model.embed_tokens.weight"]);
}
Forward(inputIds, attentionMask, positionIds, pastKeyValues);
elementsInKVCachePerToken = (long long)block_cnt *
(pastKeyValues[0].first.dims[0] * pastKeyValues[0].first.dims[2] +
pastKeyValues[0].second.dims[0] * pastKeyValues[0].second.dims[2]);
printf("finish.\n");
}
}
3 changes: 3 additions & 0 deletions src/models/moe.cpp
@@ -874,6 +874,9 @@ namespace fastllm {
}
Forward(inputIds, attentionMask, positionIds, pastKeyValues);
this->num_experts_per_tok = oldTopk;
elementsInKVCachePerToken = (long long)block_cnt *
(pastKeyValues[0].first.dims[0] * pastKeyValues[0].first.dims[2] +
pastKeyValues[0].second.dims[0] * pastKeyValues[0].second.dims[2]);
printf("finish.\n");
}
}
3 changes: 3 additions & 0 deletions src/models/moss.cpp
@@ -311,6 +311,9 @@ namespace fastllm {
Data(DataType::FLOAT32)));
}
Forward(inputIds, attentionMask, positionIds, pastKeyValues);
elementsInKVCachePerToken = (long long)block_cnt *
(pastKeyValues[0].first.dims[0] * pastKeyValues[0].first.dims[2] +
pastKeyValues[0].second.dims[0] * pastKeyValues[0].second.dims[2]);
printf("finish.\n");
}

3 changes: 3 additions & 0 deletions src/models/qwen.cpp
@@ -514,6 +514,9 @@ namespace fastllm {
Data(DataType::FLOAT32)));
}
Forward(inputIds, attentionMask, positionIds, pastKeyValues);
elementsInKVCachePerToken = (long long)block_cnt *
(pastKeyValues[0].first.dims[0] * pastKeyValues[0].first.dims[2] +
pastKeyValues[0].second.dims[0] * pastKeyValues[0].second.dims[2]);
printf("finish.\n");
}

