diff --git a/include/models/basellm.h b/include/models/basellm.h
index 33960c06..c32731c7 100644
--- a/include/models/basellm.h
+++ b/include/models/basellm.h
@@ -22,6 +22,10 @@ using RuntimeResultBatch = std::function <void(int index, std::vector <std::string> &contents)>;
+    enum ResponseContextError {
+        ResponseContextErrorNone = 0, ResponseContextErrorPromptTooLong
+    };
+
     struct ResponseContext {
         bool isEnding = false;
         std::vector <std::pair <Data, Data> > pastKeyValues;
@@ -30,6 +34,7 @@ namespace fastllm {
         std::queue <std::vector <float>*> resultLogits;
         GenerationConfig generationConfig;
         LastTokensUnit tokens;
+        ResponseContextError error = ResponseContextErrorNone;

         int preTokens = 0;
         int curTokens = 0;
@@ -232,6 +237,7 @@ namespace fastllm {
         std::string adapterName;

         int tokensLimit = -1;
+        int promptLimit = -1;
         PastKVCacheManager pastKVCacheManager;
         bool saveHistoryChat = false;
diff --git a/src/models/basellm.cpp b/src/models/basellm.cpp
index ab1f3862..321666ab 100644
--- a/src/models/basellm.cpp
+++ b/src/models/basellm.cpp
@@ -513,9 +513,13 @@ namespace fastllm {
             }
             model->tokensLimit = maxTotalLens;
+            int limit = maxTotalLens;
+            model->promptLimit = limit * 2 / 3;
+
             if (model->verbose) {
                 printf("Fastllm KV Cache Limit: %f MB.\n", (double)kvCacheLimit / 1024 / 1024);
                 printf("Fastllm KV Cache Token limit: %d tokens.\n", maxTotalLens);
+                printf("Fastllm Prompt Token limit: %d tokens.\n", std::min(model->max_positions, model->promptLimit));
                 printf("Fastllm Batch limit: %d.\n", maxBatch);
             }
@@ -538,7 +542,7 @@ namespace fastllm {
                 std::unique_lock <std::mutex> dictLocker(model->dictLocker);

                 int limit = maxTotalLens;
-                int promptLimit = limit * 2 / 3;
+                int promptLimit = model->promptLimit;

                 int lenSum = 0;
                 for (auto &it: model->responseContextDict.dicts) {
@@ -566,6 +570,11 @@ namespace fastllm {
                     if (!isPrompt && it.second->preTokens == 0) {
                         continue;
                     }
+                    if (it.second->currentTokens.size() > promptLimit) {
+                        it.second->isEnding = true;
+                        it.second->error = ResponseContextErrorPromptTooLong;
+                        continue;
+                    }

                     int outputLimit = it.second->generationConfig.output_token_limit;
                     outputLimit = (outputLimit < 0 ? 128 : outputLimit);
@@ -794,7 +803,11 @@ printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)to
         } else {
             if (context->isEnding) {
                 responseContextDict.RemoveHandle(handleId);
-                return -1;
+                if (context->error == ResponseContextErrorNone) {
+                    return -1;
+                } else if (context->error == ResponseContextErrorPromptTooLong) {
+                    return -2;
+                }
             }
         }
         dictLocker.unlock();
diff --git a/tools/fastllm_pytools/llm.py b/tools/fastllm_pytools/llm.py
index 93ef893a..851dd323 100644
--- a/tools/fastllm_pytools/llm.py
+++ b/tools/fastllm_pytools/llm.py
@@ -496,7 +496,7 @@ def generate(
         for i in range(len(inputs)):
             while True:
                 cur_token = fastllm_lib.fetch_response_llm_model(self.model, handles[i])
-                if cur_token == -1:
+                if cur_token <= -1:
                     break
                 outputs[i].append(cur_token)
         return outputs
@@ -697,7 +697,7 @@ def stream_response(self,
                 if not(fastllm_lib.can_fetch_response_llm_model(self.model, handle)):
                     continue
                 cur = fastllm_lib.fetch_response_llm_model(self.model, handle)
-                if (cur == -1):
+                if (cur <= -1):
                     break
                 tokens.append(cur)
                 ret = tokenizer.decode(tokens)
@@ -789,8 +789,10 @@ async def stream_response_async(self,
                     await asyncio.sleep(0)
                     continue
                 cur = fastllm_lib.fetch_response_llm_model(self.model, handle)
-                if (cur == -1):
-                    break;
+                if (cur <= -1):
+                    if (cur == -2):
+                        yield "prompt too long"
+                    break
                 tokens.append(cur)
                 ret = tokenizer.decode(tokens)
                 if (ret.encode().find(b'\xef\xbf\xbd') == -1):
@@ -857,7 +859,7 @@ def stream_response_raw(self,
         total_bytes = b''
         while True:
             cur_token = fastllm_lib.fetch_response_llm_model(self.model, handle)
-            if cur_token == -1:
+            if cur_token <= -1:
                 break

             cur_bytes = self.tokenizer_decode_token(cur_token)
@@ -882,7 +884,7 @@ def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max
         result = [];
         while True:
             cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
-            if (cur == -1):
+            if (cur <= -1):
                 break;
             result.append(cur);
         response = tokenizer.decode(result);
@@ -914,7 +916,7 @@ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = No
         tokens = [];
         while True:
             cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
-            if (cur == -1):
+            if (cur <= -1):
                 break;
             tokens.append(cur);
         response = tokenizer.decode(tokens);