diff --git a/src/devices/cuda/fastllm-cuda.cu b/src/devices/cuda/fastllm-cuda.cu index 14c9aa68..b273074d 100644 --- a/src/devices/cuda/fastllm-cuda.cu +++ b/src/devices/cuda/fastllm-cuda.cu @@ -2620,8 +2620,25 @@ void FastllmCudaClearBigBuffer() { for (auto &it : bigBuffersMap) { auto &bigBuffers = it.second; std::vector temp; + long long littleMemSum = 0; + long long littleMemSumLimit = 300 * 1024 * 1024; // 留一小部分复用 + std::vector > v; for (int i = 0; i < bigBuffers.size(); i++) { if (!bigBuffers[i].busy) { + v.push_back(std::make_pair(bigBuffers[i].size, i)); + } + } + sort(v.begin(), v.end()); + std::set littleMemIds; + for (int i = 0; i < v.size(); i++) { + littleMemSum += v[i].first; + if (littleMemSum > littleMemSumLimit) { + break; + } + littleMemIds.insert(v[i].second); + } + for (int i = 0; i < bigBuffers.size(); i++) { + if (!bigBuffers[i].busy && littleMemIds.find(i) == littleMemIds.end()) { state = cudaSetDevice(it.first); state = cudaFree(bigBuffers[i].data); if (cudaSuccess != state)