Skip to content

Commit

Permalink
fix: 多卡剩余显存计算
Browse files Browse the repository at this point in the history
  • Loading branch information
黄宇扬 committed Jul 24, 2024
1 parent b3b3503 commit 7490802
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
5 changes: 5 additions & 0 deletions src/devices/cuda/fastllm-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ std::vector <long long> FastllmCudaGetFreeSizes() {
// Result vector that is returned to the caller at the end of the function.
std::vector <long long> ret;

// Remember which device is current on entry, because the loop below switches
// devices and the original one is restored after it.
// NOTE(review): the return code of cudaGetDevice is not checked; if the call
// fails without writing, `id` keeps its -1 initializer.
int id = -1;
cudaGetDevice(&id);

// Iterate over all devices.
for (int i = 0; i < deviceCount; ++i) {
cudaDeviceProp prop;
error = cudaGetDeviceProperties(&prop, i);
Expand All @@ -129,6 +132,7 @@ std::vector <long long> FastllmCudaGetFreeSizes() {
// printf(" Total global memory: %zu bytes\n", prop.totalGlobalMem);

// Get the memory usage of device i. cudaMemGetInfo reports on the current
// device, so device i is made current first.
// NOTE(review): neither return code is checked; `free` stays 0 if the query
// fails without writing to it.
cudaSetDevice(i);
size_t free = 0, total = 0;
cudaMemGetInfo(&free, &total);
// Only the free amount is recorded; `total` is not used in the lines shown.
ret.push_back(free);
Expand All @@ -138,6 +142,7 @@ std::vector <long long> FastllmCudaGetFreeSizes() {
// Report the failed properties query: numeric error code plus CUDA's message.
// NOTE(review): no push_back is visible on this path; if none exists, a failed
// device shifts later entries so that the vector index no longer equals the
// device ordinal -- confirm against callers that index the result by device id.
printf("cudaGetDeviceProperties returned %d\n-> %s\n", (int)error, cudaGetErrorString(error));
}
}
// Switch back to the device recorded in `id` before the loop.
cudaSetDevice(id);
return ret;
}

Expand Down
22 changes: 20 additions & 2 deletions src/models/basellm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,9 +489,21 @@ namespace fastllm {
// Default limit when CUDA is not compiled in: 16LL << 30 (16 * 2^30;
// presumably bytes, since the CUDA path below sums free-memory sizes into it).
long long kvCacheLimit = 16LL << 30;
#ifdef USE_CUDA
auto freeSizes = FastllmCudaGetFreeSizes();
// Collect the distinct device ids parsed from every device-map key that
// starts with "cuda".
auto dmap = GetDeviceMap();
std::set <int> deviceIds;
for (auto &it : dmap) {
if (StartWith(it.first, "cuda")) {
for (int id : ParseDeviceIds(it.first, "cuda")) {
deviceIds.insert(id);
}
}
}
// No "cuda" key contributed an id: fall back to device 0.
if (deviceIds.size() == 0) {
deviceIds.insert(0);
}
// Recompute the limit from scratch: each device contributes its free size
// minus a 2LL << 30 reserve, floored at 0.
kvCacheLimit = 0;
// NOTE(review): this listing is a diff view with the +/- markers stripped; the
// next two lines look like the removed version of the loop (summing over every
// entry of freeSizes), which is why the braces do not balance here -- confirm
// against the actual file.
for (long long i : freeSizes) {
kvCacheLimit += std::max(0LL, i - (2LL << 30));
// Sum only over the devices named in the device map.
// NOTE(review): freeSizes[id] is indexed without a visible bounds check;
// confirm every id is < freeSizes.size() (e.g. when the device map names a
// device that does not exist, or the free-size query returned fewer entries).
for (int id : deviceIds) {
kvCacheLimit += std::max(0LL, freeSizes[id] - (2LL << 30));
}
#endif
if (model->kvCacheLimit > 0) {
Expand Down Expand Up @@ -803,10 +815,14 @@ printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)to
} else {
// The context is marked as ending: drop its handle, release dictLocker, wake
// one waiter on dictCV, then return a status code.
if (context->isEnding) {
responseContextDict.RemoveHandle(handleId);
// The unlock precedes the notify, presumably so the woken thread can take
// the lock right away -- confirm the intended ordering.
dictLocker.unlock();
dictCV.notify_one();
// NOTE(review): context->error is read after RemoveHandle(handleId); confirm
// that RemoveHandle does not destroy the object `context` points to.
// Return codes: -1 when there was no error, -2 when the prompt was too long,
// and -1 again for any other error value.
if (context->error == ResponseContextErrorNone) {
return -1;
} else if (context->error == ResponseContextErrorPromptTooLong) {
return -2;
} else {
return -1;
}
}
}
Expand Down Expand Up @@ -836,6 +852,8 @@ printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)to
} else {
// The context is marked as ending: drop its handle, release dictLocker, wake
// one waiter on dictCV, and return -1.
if (context->isEnding) {
responseContextDict.RemoveHandle(handleId);
// The unlock precedes the notify, presumably so the woken thread can take
// the lock right away -- confirm the intended ordering.
dictLocker.unlock();
dictCV.notify_one();
return -1;
}
}
Expand Down

0 comments on commit 7490802

Please sign in to comment.