Support aborting requests
黄宇扬 committed Jul 30, 2024
1 parent 8eb8197 commit 49e45e3
Showing 6 changed files with 100 additions and 31 deletions.
6 changes: 5 additions & 1 deletion include/models/basellm.h
@@ -27,7 +27,9 @@ namespace fastllm {
};

struct ResponseContext {
bool isEnding = false;
bool isEnding = false; // This request has finished; no further forward passes are needed, but its generated tokens may not have been fetched yet
bool isAbort = false; // This request has been aborted: nobody will fetch it anymore, so it can be removed once inference finishes

std::vector <std::pair <Data, Data> > pastKeyValues;
std::vector <int> currentTokens;
std::queue <int> resultTokenQueue;
@@ -174,6 +176,8 @@ namespace fastllm {

virtual int FetchResponseLogits(int handleId, std::vector <float> &logits); // Fetch the output logits for the given handle

virtual void AbortResponse(int handleId); // Abort the request identified by handleId

virtual void SaveLowBitModel(const std::string &fileName, int bit); // Save as a quantized model

virtual void SaveModel(const std::string &fileName); // Export the model directly
22 changes: 22 additions & 0 deletions src/models/basellm.cpp
@@ -552,6 +552,17 @@ namespace fastllm {
std::vector <std::vector <float>* > logits;

std::unique_lock<std::mutex> dictLocker(model->dictLocker);

// First, remove requests that have already been aborted
std::set <int> abortHandles;
for (auto &it: model->responseContextDict.dicts) {
if (it.second->isAbort) {
abortHandles.insert(it.first);
}
}
for (auto &it : abortHandles) {
model->responseContextDict.RemoveHandle(it);
}

int limit = maxTotalLens;
int promptLimit = model->promptLimit;
@@ -818,6 +829,17 @@ printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)to
return (context->resultTokenQueue.size() > 0 || context->isEnding);
}
}

void basellm::AbortResponse(int handleId) {
std::unique_lock<std::mutex> dictLocker(this->dictLocker);
ResponseContext *context = responseContextDict.GetHandle(handleId);

if (context == nullptr) {
return;
} else {
context->isAbort = true;
}
}

int basellm::FetchResponseTokens(int handleId) {
std::unique_lock<std::mutex> dictLocker(this->dictLocker);
52 changes: 35 additions & 17 deletions tools/fastllm_pytools/llm.py
@@ -67,6 +67,8 @@
fastllm_lib.can_fetch_response_llm_model.argtypes = [ctypes.c_int, ctypes.c_int]
fastllm_lib.can_fetch_response_llm_model.restype = ctypes.c_bool

fastllm_lib.abort_response_llm_model.argtypes = [ctypes.c_int, ctypes.c_int]

fastllm_lib.make_history_llm_model.argtype = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p]
fastllm_lib.make_history_llm_model.restype = ctypes.c_char_p

@@ -844,7 +846,7 @@ def add_cache(self,
print("add_cache failed: need hf_tokenizer.")
exit(0)

async def stream_response_async(self,
def launch_stream_response(self,
query: Union[str, List[Dict[str, str]]],
history: List[Tuple[str, str]] = None,
max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
@@ -874,7 +876,27 @@ async def stream_response_async(self,
handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
False, stop_token_len, stop_token_list)
tokens = [];
return handle
else:
prompt = ""
if (conversation != None and len(conversation) != 0):
prompt = self.apply_chat_template(conversation)
else:
prompt = query if self.direct_query else self.get_prompt(query, history)
stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
ctypes.c_int(max_length), ctypes.c_bool(do_sample), ctypes.c_float(top_p), ctypes.c_int(top_k),
ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False),
stop_token_len, stop_token_list)
return handle

def abort_handle(self, handle):
fastllm_lib.abort_response_llm_model(self.model, handle)

async def stream_response_handle_async(self, handle):
if (self.hf_tokenizer != None and hasattr(self.hf_tokenizer, "chat_template") and self.hf_tokenizer.chat_template != ""):
tokenizer = self.hf_tokenizer
tokens = []
while True:
if not(fastllm_lib.can_fetch_response_llm_model(self.model, handle)):
await asyncio.sleep(0)
@@ -894,16 +916,6 @@ async def stream_response_async(self,
if len(tokens) > 0:
yield tokenizer.decode(tokens)
else:
prompt = ""
if (conversation != None and len(conversation) != 0):
prompt = self.apply_chat_template(conversation)
else:
prompt = query if self.direct_query else self.get_prompt(query, history)
stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
ctypes.c_int(max_length), ctypes.c_bool(do_sample), ctypes.c_float(top_p), ctypes.c_int(top_k),
ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False),
stop_token_len, stop_token_list)
res = ""
ret = b''
fail_cnt = 0
@@ -925,11 +937,17 @@ async def stream_response_async(self,
fail_cnt = 0
if (cur == "<flmeos>"):
break
if one_by_one:
yield cur
else:
res += cur
yield res
yield cur

async def stream_response_async(self,
query: Union[str, List[Dict[str, str]]],
history: List[Tuple[str, str]] = None,
max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
one_by_one = True, stop_token_ids: List[int] = None, add_generation_prompt = True):
handle = self.launch_stream_response(query, history, max_length, do_sample, top_p, top_k, temperature,
repeat_penalty, one_by_one, stop_token_ids, add_generation_prompt)
async for ret in self.stream_response_handle_async(handle):
yield ret

def stream_response_raw(self,
input_tokens: List[int],
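The split API above (launch_stream_response / stream_response_handle_async / abort_handle) lets a caller stop generation for a handle while another coroutine is still consuming it. Below is a minimal usage sketch, not part of this diff; the model path, the loading call llm.model(...), and the timeout value are placeholder assumptions.

```python
import asyncio
from ftllm import llm

async def generate_with_timeout(model, prompt, timeout = 5.0):
    # Launch generation and get a handle instead of an async generator.
    handle = model.launch_stream_response(prompt, max_length = 512)
    pieces = []

    async def consume():
        async for chunk in model.stream_response_handle_async(handle):
            pieces.append(chunk)

    try:
        await asyncio.wait_for(consume(), timeout = timeout)
    except asyncio.TimeoutError:
        # Mark the handle as aborted; the C++ side drops it on its next scheduling pass.
        model.abort_handle(handle)
    return "".join(pieces)

if __name__ == "__main__":
    model = llm.model("model.flm")  # placeholder model path
    print(asyncio.run(generate_with_timeout(model, "Hello")))
```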
34 changes: 26 additions & 8 deletions tools/fastllm_pytools/openai_server/fastllm_completion.py
@@ -9,6 +9,7 @@
import uuid
from openai.types.chat import (ChatCompletionContentPartParam,
ChatCompletionRole)
from starlette.background import BackgroundTask

from .protocal.openai_protocol import *
from ftllm import llm
@@ -75,7 +76,8 @@ def _parse_chat_message_content(
async def create_chat_completion(
self, request: ChatCompletionRequest, raw_request: Request
) -> Union[ErrorResponse, AsyncGenerator[str, None],
ChatCompletionResponse]:
ChatCompletionResponse,
Tuple[AsyncGenerator[str, None], AsyncGenerator]]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create
@@ -118,28 +120,40 @@ async def create_chat_completion(
frequency_penalty = request.frequency_penalty

max_length = request.max_tokens if request.max_tokens else 8192
input_token_len = self.model.get_input_token_len(messages)
#logging.info(request)
logging.info(f"fastllm input message: {messages}")
#logging.info(f"input tokens: {input_token_len}")

input_token_len = self.model.get_input_token_len(messages)

result_generator = self.model.stream_response_async(messages,
handle = self.model.launch_stream_response(messages,
max_length = max_length, do_sample = True,
top_p = request.top_p, top_k = request.top_k, temperature = request.temperature,
repeat_penalty = frequency_penalty, one_by_one = True)
result_generator = self.model.stream_response_handle_async(handle)
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id, input_token_len)
return (self.chat_completion_stream_generator(
request, raw_request, result_generator, request_id, input_token_len),
BackgroundTask(self.check_disconnect, raw_request, request_id, handle))
else:
try:
return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id, input_token_len)
request, raw_request, handle, result_generator, request_id, input_token_len)
except ValueError as e:
return self.create_error_response(str(e))


async def check_disconnect(self, raw_request: Request, request_id, handle: int):
while True:
if await raw_request.is_disconnected():
self.model.abort_handle(handle)
logging.info(f"Abort request: {request_id}")
return
await asyncio.sleep(1) # polling interval

async def chat_completion_full_generator(
self, request: ChatCompletionRequest, raw_request: Request,
handle: int,
result_generator: AsyncIterator,
request_id: str,
input_token_len: int) -> Union[ErrorResponse, ChatCompletionResponse]:
@@ -150,6 +164,10 @@ async def chat_completion_full_generator(
async for res in result_generator:
result += res
completion_tokens += 1
if await raw_request.is_disconnected():
self.model.abort_handle(handle)
logging.info(f"Abort request: {request_id}")
return self.create_error_response("Client disconnected")

choice_data = ChatCompletionResponseChoice(
index=0,
@@ -173,7 +191,7 @@


async def chat_completion_stream_generator(
self, request: ChatCompletionRequest,
self, request: ChatCompletionRequest, raw_request: Request,
result_generator: AsyncIterator,
request_id: str,
input_token_len: int) -> AsyncGenerator[str, None]:
11 changes: 6 additions & 5 deletions tools/fastllm_pytools/server.py
@@ -37,14 +37,15 @@ async def create_chat_completion(request: ChatCompletionRequest,
generator = await fastllm_completion.create_chat_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
return JSONResponse(content = generator.model_dump(),
status_code = generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
return StreamingResponse(content = generator[0],
background = generator[1],
media_type = "text/event-stream")
else:
assert isinstance(generator, ChatCompletionResponse)
return JSONResponse(content=generator.model_dump())
return JSONResponse(content = generator.model_dump())

def init_logging(log_level = logging.INFO, log_file:str = None):
logging_format = '%(asctime)s %(process)d %(filename)s[line:%(lineno)d] %(levelname)s: %(message)s'
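For reference, a self-contained sketch of the disconnect-handling pattern this commit wires up on the server side: the token stream is sent via StreamingResponse while a starlette BackgroundTask polls is_disconnected() and aborts the underlying handle. The handle, abort callback, and token stream below are placeholders for the objects created in fastllm_completion.py, not part of this diff.

```python
import asyncio
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from starlette.background import BackgroundTask

app = FastAPI()

async def check_disconnect(raw_request: Request, handle, abort):
    # Poll once per second; abort generation as soon as the client goes away.
    while not await raw_request.is_disconnected():
        await asyncio.sleep(1)
    abort(handle)

@app.post("/v1/chat/completions")
async def chat(raw_request: Request):
    handle = 0                                    # placeholder for launch_stream_response(...)
    abort = lambda h: print(f"abort handle {h}")  # placeholder for model.abort_handle

    async def token_stream():
        for piece in ["data: hello\n\n", "data: world\n\n"]:  # placeholder generator
            yield piece
            await asyncio.sleep(0.5)

    return StreamingResponse(token_stream(),
                             background = BackgroundTask(check_disconnect, raw_request, handle, abort),
                             media_type = "text/event-stream")
```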
6 changes: 6 additions & 0 deletions tools/src/pytools.cpp
@@ -352,6 +352,12 @@ extern "C" {
return model->CanFetchResponse(handleId);
}

// Abort the request identified by handleId
DLL_EXPORT void abort_response_llm_model(int modelId, int handleId) {
auto model = models.GetModel(modelId);
model->AbortResponse(handleId);
}

DLL_EXPORT char *fetch_response_str_llm_model(int modelId, int handleId) {
auto model = models.GetModel(modelId);
int ret = model->FetchResponseTokens(handleId);
