From 2ed99931977c45b69ce54262de8fbeff4de368ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?= Date: Fri, 13 Dec 2024 16:19:02 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=B8=80=E4=B8=8B=E8=B0=83?= =?UTF-8?q?=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/models/basellm.cpp | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/models/basellm.cpp b/src/models/basellm.cpp index 1494cc9..57af96b 100644 --- a/src/models/basellm.cpp +++ b/src/models/basellm.cpp @@ -588,6 +588,11 @@ namespace fastllm { currentActivate++; } } + std::vector <std::pair <int, int> > orders; + for (auto &it : model->responseContextDict.dicts) { + orders.push_back(std::make_pair(-(int)it.second->currentTokens.size(), it.first)); + } + sort(orders.begin(), orders.end()); for (int isPrompt = 1; isPrompt >= 0; isPrompt--) { int cnt = 0; @@ -599,7 +604,12 @@ namespace fastllm { continue; } */ - for (auto &it: model->responseContextDict.dicts) { + + int currentMaxLen = 0; + + // for (auto &it: model->responseContextDict.dicts) { + for (auto &ii : orders) { + auto &it = *model->responseContextDict.dicts.find(ii.second); if (it.second->isEnding) { continue; } @@ -640,6 +650,10 @@ namespace fastllm { lenSum += predictLen; } } else { + if (it.second->currentTokens.size() * 2 < currentMaxLen) { + continue; + } + currentMaxLen = std::max(currentMaxLen, (int)it.second->currentTokens.size()); lenSum += it.second->currentTokens.size(); currentActivate++; }