diff --git a/include/models/basellm.h b/include/models/basellm.h
index c32731c7..716757c1 100644
--- a/include/models/basellm.h
+++ b/include/models/basellm.h
@@ -27,7 +27,9 @@ namespace fastllm {
     };
 
     struct ResponseContext {
-        bool isEnding = false;
+        bool isEnding = false; // This request has finished processing and needs no more forward passes, but its generated tokens may not all have been fetched yet
+        bool isAbort = false; // This request has been aborted, i.e. nothing will fetch it anymore, so it can be deleted once inference finishes
+
         std::vector <std::pair <Data, Data> > pastKeyValues;
         std::vector <int> currentTokens;
         std::queue <int> resultTokenQueue;
@@ -174,6 +176,8 @@
         virtual int FetchResponseLogits(int handleId, std::vector <float> &logits); // Fetch the output logits for the specified handle
 
+        virtual void AbortResponse(int handleId); // Abort the request identified by handleId
+
         virtual void SaveLowBitModel(const std::string &fileName, int bit); // Save as a quantized model
 
         virtual void SaveModel(const std::string &fileName); // Export the model directly
diff --git a/src/models/basellm.cpp b/src/models/basellm.cpp
index ad390990..0468bf63 100644
--- a/src/models/basellm.cpp
+++ b/src/models/basellm.cpp
@@ -552,6 +552,17 @@ namespace fastllm {
                     std::vector <std::vector <float>* > logits;
                     std::unique_lock <std::mutex> dictLocker(model->dictLocker);
+
+                    // First, remove requests that have already been aborted
+                    std::set <int> abortHandles;
+                    for (auto &it: model->responseContextDict.dicts) {
+                        if (it.second->isAbort) {
+                            abortHandles.insert(it.first);
+                        }
+                    }
+                    for (auto &it : abortHandles) {
+                        model->responseContextDict.RemoveHandle(it);
+                    }
 
                     int limit = maxTotalLens;
                     int promptLimit = model->promptLimit;
@@ -818,6 +829,17 @@ printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)total / spend);
             return (context->resultTokenQueue.size() > 0 || context->isEnding);
         }
     }
+
+    void basellm::AbortResponse(int handleId) {
+        std::unique_lock <std::mutex> dictLocker(this->dictLocker);
+        ResponseContext *context = responseContextDict.GetHandle(handleId);
+
+        if (context == nullptr) {
+            return;
+        } else {
+            context->isAbort = true;
+        }
+    }
 
     int basellm::FetchResponseTokens(int handleId) {
         std::unique_lock <std::mutex> dictLocker(this->dictLocker);
diff --git a/tools/fastllm_pytools/llm.py b/tools/fastllm_pytools/llm.py
index 0bfdab16..396299ee 100644
--- a/tools/fastllm_pytools/llm.py
+++ b/tools/fastllm_pytools/llm.py
@@ -67,6 +67,8 @@
 fastllm_lib.can_fetch_response_llm_model.argtypes = [ctypes.c_int, ctypes.c_int]
 fastllm_lib.can_fetch_response_llm_model.restype = ctypes.c_bool
 
+fastllm_lib.abort_response_llm_model.argtypes = [ctypes.c_int, ctypes.c_int]
+
 fastllm_lib.make_history_llm_model.argtype = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p]
 fastllm_lib.make_history_llm_model.restype = ctypes.c_char_p
@@ -844,7 +846,7 @@ def add_cache(self,
             print("add_cache failed: need hf_tokenizer.")
             exit(0)
 
-    async def stream_response_async(self,
+    def launch_stream_response(self,
                               query: Union[str, List[Dict[str, str]]],
                               history: List[Tuple[str, str]] = None,
                               max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
@@ -874,7 +876,27 @@ async def stream_response_async(self,
             handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
                                                            max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
                                                            False, stop_token_len, stop_token_list)
-            tokens = [];
+            return handle
+        else:
+            prompt = ""
+            if (conversation != None and len(conversation) != 0):
+                prompt = self.apply_chat_template(conversation)
+            else:
+                prompt = query if self.direct_query else self.get_prompt(query, history)
+            stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
+            handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
+                                                               ctypes.c_int(max_length), ctypes.c_bool(do_sample), ctypes.c_float(top_p), ctypes.c_int(top_k),
+                                                               ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False),
+                                                               stop_token_len, stop_token_list)
+            return handle
+
+    def abort_handle(self, handle):
+        fastllm_lib.abort_response_llm_model(self.model, handle)
+
+    async def stream_response_handle_async(self, handle):
+        if (self.hf_tokenizer != None and hasattr(self.hf_tokenizer, "chat_template") and self.hf_tokenizer.chat_template != ""):
+            tokenizer = self.hf_tokenizer
+            tokens = []
             while True:
                 if not(fastllm_lib.can_fetch_response_llm_model(self.model, handle)):
                     await asyncio.sleep(0)
@@ -894,16 +916,6 @@ async def stream_response_async(self,
             if len(tokens) > 0:
                 yield tokenizer.decode(tokens)
         else:
-            prompt = ""
-            if (conversation != None and len(conversation) != 0):
-                prompt = self.apply_chat_template(conversation)
-            else:
-                prompt = query if self.direct_query else self.get_prompt(query, history)
-            stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
-            handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
-                                                               ctypes.c_int(max_length), ctypes.c_bool(do_sample), ctypes.c_float(top_p), ctypes.c_int(top_k),
-                                                               ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False),
-                                                               stop_token_len, stop_token_list)
             res = ""
             ret = b''
             fail_cnt = 0
@@ -925,11 +937,17 @@ async def stream_response_async(self,
                 fail_cnt = 0
                 if (cur == ""):
                     break
-                if one_by_one:
-                    yield cur
-                else:
-                    res += cur
-                    yield res
+                yield cur
+
+    async def stream_response_async(self,
+                              query: Union[str, List[Dict[str, str]]],
+                              history: List[Tuple[str, str]] = None,
+                              max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
+                              one_by_one = True, stop_token_ids: List[int] = None, add_generation_prompt = True):
+        handle = self.launch_stream_response(query, history, max_length, do_sample, top_p, top_k, temperature,
+                                             repeat_penalty, one_by_one, stop_token_ids, add_generation_prompt)
+        async for ret in self.stream_response_handle_async(handle):
+            yield ret
 
     def stream_response_raw(self,
                             input_tokens: List[int],
diff --git a/tools/fastllm_pytools/openai_server/fastllm_completion.py b/tools/fastllm_pytools/openai_server/fastllm_completion.py
index 1836a75e..f8a34df3 100644
--- a/tools/fastllm_pytools/openai_server/fastllm_completion.py
+++ b/tools/fastllm_pytools/openai_server/fastllm_completion.py
@@ -9,6 +9,7 @@
 import uuid
 from openai.types.chat import (ChatCompletionContentPartParam,
                                ChatCompletionRole)
+from starlette.background import BackgroundTask
 
 from .protocal.openai_protocol import *
 from ftllm import llm
@@ -75,7 +76,8 @@ def _parse_chat_message_content(
     async def create_chat_completion(
         self, request: ChatCompletionRequest, raw_request: Request
     ) -> Union[ErrorResponse, AsyncGenerator[str, None],
-               ChatCompletionResponse]:
+               ChatCompletionResponse,
+               Tuple[AsyncGenerator[str, None], AsyncGenerator]]:
         """Completion API similar to OpenAI's API.
 
         See https://platform.openai.com/docs/api-reference/chat/create
@@ -118,28 +120,40 @@ async def create_chat_completion(
         frequency_penalty = request.frequency_penalty
 
         max_length = request.max_tokens if request.max_tokens else 8192
-        input_token_len = self.model.get_input_token_len(messages)
         #logging.info(request)
         logging.info(f"fastllm input message: {messages}")
         #logging.info(f"input tokens: {input_token_len}")
+
+        input_token_len = self.model.get_input_token_len(messages)
 
-        result_generator = self.model.stream_response_async(messages,
+        handle = self.model.launch_stream_response(messages,
                             max_length = max_length, do_sample = True,
                             top_p = request.top_p, top_k = request.top_k,
                             temperature = request.temperature, repeat_penalty = frequency_penalty, one_by_one = True)
+        result_generator = self.model.stream_response_handle_async(handle)
         # Streaming response
         if request.stream:
-            return self.chat_completion_stream_generator(
-                request, result_generator, request_id, input_token_len)
+            return (self.chat_completion_stream_generator(
+                request, raw_request, result_generator, request_id, input_token_len),
+                BackgroundTask(self.check_disconnect, raw_request, request_id, handle))
         else:
             try:
                 return await self.chat_completion_full_generator(
-                    request, raw_request, result_generator, request_id, input_token_len)
+                    request, raw_request, handle, result_generator, request_id, input_token_len)
             except ValueError as e:
                 return self.create_error_response(str(e))
-
+
+    async def check_disconnect(self, raw_request: Request, request_id, handle: int):
+        while True:
+            if await raw_request.is_disconnected():
+                self.model.abort_handle(handle)
+                logging.info(f"Abort request: {request_id}")
+                return
+            await asyncio.sleep(1)  # check interval
+
     async def chat_completion_full_generator(
             self, request: ChatCompletionRequest, raw_request: Request,
+            handle: int,
             result_generator: AsyncIterator,
             request_id: str, input_token_len: int) -> Union[ErrorResponse, ChatCompletionResponse]:
@@ -150,6 +164,10 @@ async def chat_completion_full_generator(
         async for res in result_generator:
             result += res
             completion_tokens += 1
+            if await raw_request.is_disconnected():
+                self.model.abort_handle(handle)
+                logging.info(f"Abort request: {request_id}")
+                return self.create_error_response("Client disconnected")
 
         choice_data = ChatCompletionResponseChoice(
             index=0,
@@ -173,7 +191,7 @@ async def chat_completion_full_generator(
 
     async def chat_completion_stream_generator(
-            self, request: ChatCompletionRequest,
+            self, request: ChatCompletionRequest, raw_request: Request,
             result_generator: AsyncIterator,
             request_id: str, input_token_len: int) -> AsyncGenerator[str, None]:
diff --git a/tools/fastllm_pytools/server.py b/tools/fastllm_pytools/server.py
index 82b9225f..2c6ae3a4 100644
--- a/tools/fastllm_pytools/server.py
+++ b/tools/fastllm_pytools/server.py
@@ -37,14 +37,15 @@ async def create_chat_completion(request: ChatCompletionRequest,
     generator = await fastllm_completion.create_chat_completion(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
-        return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+        return JSONResponse(content = generator.model_dump(),
+                            status_code = generator.code)
     if request.stream:
-        return StreamingResponse(content=generator,
-                                 media_type="text/event-stream")
+        return StreamingResponse(content = generator[0],
+                                 background = generator[1],
+                                 media_type = "text/event-stream")
     else:
         assert isinstance(generator, ChatCompletionResponse)
-        return JSONResponse(content=generator.model_dump())
+        return JSONResponse(content = generator.model_dump())
 
 def init_logging(log_level = logging.INFO, log_file:str = None):
     logging_format = '%(asctime)s %(process)d %(filename)s[line:%(lineno)d] %(levelname)s: %(message)s'
diff --git a/tools/src/pytools.cpp b/tools/src/pytools.cpp
index e7affd5d..4e4ba012 100644
--- a/tools/src/pytools.cpp
+++ b/tools/src/pytools.cpp
@@ -352,6 +352,12 @@ extern "C" {
        return model->CanFetchResponse(handleId);
    }

+    // Abort the request identified by handleId
+    DLL_EXPORT void abort_response_llm_model(int modelId, int handleId) {
+        auto model = models.GetModel(modelId);
+        model->AbortResponse(handleId);
+    }
+
    DLL_EXPORT char *fetch_response_str_llm_model(int modelId, int handleId) {
        auto model = models.GetModel(modelId);
        int ret = model->FetchResponseTokens(handleId);
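
The Python-side flow this patch enables is: launch a request to get a handle, stream tokens from that handle, and call abort_handle as soon as the consumer goes away so the C++ main loop can drop the ResponseContext. Below is a minimal usage sketch, not part of the patch; the "model.flm" path and the early-stop condition are hypothetical placeholders, and it assumes ftllm's llm.model loader.

import asyncio

from ftllm import llm

async def main():
    model = llm.model("model.flm")  # hypothetical path to a fastllm-format model

    # launch_stream_response returns a handle immediately; token generation
    # continues inside the C++ main loop until it finishes or is aborted.
    handle = model.launch_stream_response("Hello, who are you?", max_length = 256)

    collected = ""
    async for chunk in model.stream_response_handle_async(handle):
        collected += chunk
        if len(collected) > 200:  # placeholder early-stop condition
            # Mark the request as aborted: the main loop removes its
            # ResponseContext on a later iteration instead of generating
            # all the way to max_length.
            model.abort_handle(handle)
            break

    print(collected)

asyncio.run(main())

The OpenAI-compatible server uses the same primitive: for streaming requests, create_chat_completion now returns a (generator, BackgroundTask) pair, and server.py attaches the task to the StreamingResponse so check_disconnect can abort the handle once the client drops.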