diff --git a/tools/fastllm_pytools/openai_server/fastllm_reranker.py b/tools/fastllm_pytools/openai_server/fastllm_reranker.py
new file mode 100644
index 0000000..dbbdb8b
--- /dev/null
+++ b/tools/fastllm_pytools/openai_server/fastllm_reranker.py
@@ -0,0 +1,30 @@
+import asyncio
+import logging
+import json
+import traceback
+from fastapi import Request
+
+from .protocal.openai_protocol import *
+from ftllm import llm
+
+class FastLLmReranker:
+    def __init__(self,
+                 model_name,
+                 model):
+        self.model_name = model_name
+        self.model = model
+
+    def rerank(self, request: RerankRequest, raw_request: Request):
+        query = request.query
+        pairs = []
+        for text in request.texts:
+            pairs.append([query, text])
+        scores = self.model.reranker_compute_score(pairs = pairs)
+        ret = []
+        for i in range(len(request.texts)):
+            now = {'index': i, 'score': scores[i]}
+            if (request.return_text):
+                now['text'] = request.texts[i]
+            ret.append(now)
+        ret = sorted(ret, key = lambda x : -x['score'])
+        return ret
diff --git a/tools/fastllm_pytools/openai_server/protocal/openai_protocol.py b/tools/fastllm_pytools/openai_server/protocal/openai_protocol.py
index 36cbe98..a41bb48 100644
--- a/tools/fastllm_pytools/openai_server/protocal/openai_protocol.py
+++ b/tools/fastllm_pytools/openai_server/protocal/openai_protocol.py
@@ -203,7 +203,15 @@ class CompletionStreamResponse(BaseModel):
 class EmbedRequest(BaseModel):
     inputs: str
-    normalize: Optional[bool]
-    prompt_name: Optional[str]
-    truncate: Optional[bool]
-    truncation_direction: Optional[str]
+    normalize: Optional[bool] = False
+    prompt_name: Optional[str] = "null"
+    truncate: Optional[bool] = False
+    truncation_direction: Optional[str] = 'right'
+
+class RerankRequest(BaseModel):
+    query: str
+    texts: List[str]
+    raw_scores: Optional[bool] = True
+    return_text: Optional[bool] = False
+    truncate: Optional[bool] = False
+    truncation_direction: Optional[str] = "right"
 
diff --git a/tools/fastllm_pytools/server.py b/tools/fastllm_pytools/server.py
index 7ab1e4f..415190b 100644
--- a/tools/fastllm_pytools/server.py
+++ b/tools/fastllm_pytools/server.py
@@ -10,6 +10,7 @@
 from .openai_server.protocal.openai_protocol import *
 from .openai_server.fastllm_completion import FastLLmCompletion
 from .openai_server.fastllm_embed import FastLLmEmbed
+from .openai_server.fastllm_reranker import FastLLmReranker
 from .util import make_normal_parser
 from .util import make_normal_llm_model
 
@@ -55,6 +56,13 @@ async def create_embed(request: EmbedRequest,
                        raw_request: Request):
     embedding = fastllm_embed.embedding_sentence(request, raw_request)
     return JSONResponse(embedding)
+@app.post("/v1/rerank")
+async def create_rerank(request: RerankRequest,
+                        raw_request: Request):
+    print(request)
+    scores = fastllm_reranker.rerank(request, raw_request)
+    return JSONResponse(scores)
+
 def init_logging(log_level = logging.INFO, log_file:str = None):
     logging_format = '%(asctime)s %(process)d %(filename)s[line:%(lineno)d] %(levelname)s: %(message)s'
     root = logging.getLogger()
@@ -74,4 +82,5 @@
     model.set_verbose(True)
     fastllm_completion = FastLLmCompletion(model_name = args.model_name, model = model)
     fastllm_embed = FastLLmEmbed(model_name = args.model_name, model = model)
+    fastllm_reranker = FastLLmReranker(model_name = args.model_name, model = model)
     uvicorn.run(app, host = args.host, port = args.port)
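
For reference, a minimal client-side sketch of the new `/v1/rerank` endpoint. The host and port are assumptions (they come from whatever `--host`/`--port` the server was started with); the payload fields mirror `RerankRequest`, and the response shape is the list built by `FastLLmReranker.rerank` above.

```python
# Hypothetical client call against a locally running server; the URL
# (host/port) is an assumption -- substitute your own startup values.
import requests

payload = {
    "query": "what is fastllm?",
    "texts": [
        "fastllm is a high-performance LLM inference library.",
        "The weather is nice today.",
    ],
    "return_text": True,  # ask the server to echo each text back in the result
}
resp = requests.post("http://127.0.0.1:8080/v1/rerank", json = payload)

# rerank() returns the items sorted by descending score, e.g.:
# [{'index': 0, 'score': 8.2, 'text': '...'}, {'index': 1, 'score': -3.1, 'text': '...'}]
for item in resp.json():
    print(item["index"], item["score"])
```

Note that `index` refers to the position in the original `texts` list, so callers can map results back to their inputs even when `return_text` is left at its default of `False`.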