Skip to content

Commit

Permalink
保留一部分cache
Browse files Browse the repository at this point in the history
  • Loading branch information
黄宇扬 committed Jul 8, 2024
1 parent 5700447 commit 77b8471
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/devices/cuda/fastllm-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2620,8 +2620,25 @@ void FastllmCudaClearBigBuffer() {
for (auto &it : bigBuffersMap) {
auto &bigBuffers = it.second;
std::vector <CudaMemoryBuffer> temp;
long long littleMemSum = 0;
long long littleMemSumLimit = 300 * 1024 * 1024; // 留一小部分复用
std::vector <std::pair <std::size_t, int > > v;
for (int i = 0; i < bigBuffers.size(); i++) {
if (!bigBuffers[i].busy) {
v.push_back(std::make_pair(bigBuffers[i].size, i));
}
}
sort(v.begin(), v.end());
std::set <int> littleMemIds;
for (int i = 0; i < v.size(); i++) {
littleMemSum += v[i].first;
if (littleMemSum > littleMemSumLimit) {
break;
}
littleMemIds.insert(v[i].second);
}
for (int i = 0; i < bigBuffers.size(); i++) {
if (!bigBuffers[i].busy && littleMemIds.find(i) == littleMemIds.end()) {
state = cudaSetDevice(it.first);
state = cudaFree(bigBuffers[i].data);
if (cudaSuccess != state)
Expand Down

0 comments on commit 77b8471

Please sign in to comment.