From 74908028e7e3c3b04f12e914aacf6adce107d9f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?= Date: Wed, 24 Jul 2024 16:34:42 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E5=A4=9A=E5=8D=A1=E5=89=A9=E4=BD=99?= =?UTF-8?q?=E6=98=BE=E5=AD=98=E8=AE=A1=E7=AE=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/devices/cuda/fastllm-cuda.cu | 5 +++++ src/models/basellm.cpp | 22 ++++++++++++++++++++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/devices/cuda/fastllm-cuda.cu b/src/devices/cuda/fastllm-cuda.cu index bd7b3c99..e740ee43 100644 --- a/src/devices/cuda/fastllm-cuda.cu +++ b/src/devices/cuda/fastllm-cuda.cu @@ -120,6 +120,9 @@ std::vector FastllmCudaGetFreeSizes() { std::vector ret; // 遍历所有设备 + int id = -1; + cudaGetDevice(&id); + for (int i = 0; i < deviceCount; ++i) { cudaDeviceProp prop; error = cudaGetDeviceProperties(&prop, i); @@ -129,6 +132,7 @@ std::vector FastllmCudaGetFreeSizes() { // printf(" Total global memory: %zu bytes\n", prop.totalGlobalMem); // 获取当前设备的显存使用情况 + cudaSetDevice(i); size_t free = 0, total = 0; cudaMemGetInfo(&free, &total); ret.push_back(free); @@ -138,6 +142,7 @@ std::vector FastllmCudaGetFreeSizes() { printf("cudaGetDeviceProperties returned %d\n-> %s\n", (int)error, cudaGetErrorString(error)); } } + cudaSetDevice(id); return ret; } diff --git a/src/models/basellm.cpp b/src/models/basellm.cpp index 321666ab..4f424bb6 100644 --- a/src/models/basellm.cpp +++ b/src/models/basellm.cpp @@ -489,9 +489,21 @@ namespace fastllm { long long kvCacheLimit = 16LL << 30; #ifdef USE_CUDA auto freeSizes = FastllmCudaGetFreeSizes(); + auto dmap = GetDeviceMap(); + std::set deviceIds; + for (auto &it : dmap) { + if (StartWith(it.first, "cuda")) { + for (int id : ParseDeviceIds(it.first, "cuda")) { + deviceIds.insert(id); + } + } + } + if (deviceIds.size() == 0) { + deviceIds.insert(0); + } kvCacheLimit = 0; - for (long long i : freeSizes) { - kvCacheLimit += std::max(0LL, i - (2LL << 30)); + for (int id : deviceIds) { + kvCacheLimit += std::max(0LL, freeSizes[id] - (2LL << 30)); } #endif if (model->kvCacheLimit > 0) { @@ -803,10 +815,14 @@ printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)to } else { if (context->isEnding) { responseContextDict.RemoveHandle(handleId); + dictLocker.unlock(); + dictCV.notify_one(); if (context->error == ResponseContextErrorNone) { return -1; } else if (context->error == ResponseContextErrorPromptTooLong) { return -2; + } else { + return -1; } } } @@ -836,6 +852,8 @@ printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)to } else { if (context->isEnding) { responseContextDict.RemoveHandle(handleId); + dictLocker.unlock(); + dictCV.notify_one(); return -1; } }