Commit

server: add a rerank endpoint
黄宇扬 committed Sep 25, 2024
1 parent e0e9a6b commit c9ae3fa
Showing 3 changed files with 51 additions and 4 deletions.
30 changes: 30 additions & 0 deletions tools/fastllm_pytools/openai_server/fastllm_reranker.py
@@ -0,0 +1,30 @@
import asyncio
import logging
import json
import traceback
from fastapi import Request

from .protocal.openai_protocol import *
from ftllm import llm

class FastLLmReranker:
    def __init__(self,
                 model_name,
                 model):
        self.model_name = model_name
        self.model = model

    def rerank(self, request: RerankRequest, raw_request: Request):
        # Pair the query with every candidate text and let the model score each pair.
        query = request.query
        pairs = []
        for text in request.texts:
            pairs.append([query, text])
        scores = self.model.reranker_compute_score(pairs = pairs)
        # Build one result entry per input text, optionally echoing the text back.
        ret = []
        for i in range(len(request.texts)):
            now = {'index': i, 'score': scores[i]}
            if request.return_text:
                now['text'] = request.texts[i]
            ret.append(now)
        # Return the entries sorted by descending relevance score.
        ret = sorted(ret, key = lambda x : -x['score'])
        return ret
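
For a sense of how this class is used, here is a minimal sketch that drives FastLLmReranker with a stub model. The stub, its scores, and the import path (assuming the package installs as ftllm, as the other imports suggest) are illustrative and not part of this commit:

from types import SimpleNamespace

from ftllm.openai_server.fastllm_reranker import FastLLmReranker  # assumed install path

# Stub standing in for an ftllm reranker model; the scores are invented.
class StubRerankerModel:
    def reranker_compute_score(self, pairs):
        return [0.92, 0.03, 0.41][:len(pairs)]

reranker = FastLLmReranker(model_name = "stub-reranker", model = StubRerankerModel())

# Stand-in for a RerankRequest; the field names mirror the model added in openai_protocol.py.
request = SimpleNamespace(
    query = "what is fastllm",
    texts = ["fastllm is an inference library", "a note about cooking", "llm serving"],
    return_text = True,
)

# raw_request is unused by rerank(), so None is enough for this sketch.
print(reranker.rerank(request, raw_request = None))
# -> entries sorted by descending score: index 0 first, then 2, then 1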
16 changes: 12 additions & 4 deletions tools/fastllm_pytools/openai_server/protocal/openai_protocol.py
@@ -203,7 +203,15 @@ class CompletionStreamResponse(BaseModel):

class EmbedRequest(BaseModel):
    inputs: str
-    normalize: Optional[bool]
-    prompt_name: Optional[str]
-    truncate: Optional[bool]
-    truncation_direction: Optional[str]
+    normalize: Optional[bool] = False
+    prompt_name: Optional[str] = "null"
+    truncate: Optional[bool] = False
+    truncation_direction: Optional[str] = 'right'
+
+class RerankRequest(BaseModel):
+    query: str
+    texts: List[str]
+    raw_scores: Optional[bool] = True
+    return_text: Optional[bool] = False
+    truncate: Optional[bool] = False
+    truncation_direction: Optional[str] = "right"
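
As a reference for clients, a request body matching the new RerankRequest model could look like the sketch below; the field values are invented, and the import path assumes the package installs as ftllm:

from ftllm.openai_server.protocal.openai_protocol import RerankRequest  # assumed install path

# Illustrative payload; omitted fields fall back to the defaults declared above.
payload = {
    "query": "what is fastllm",
    "texts": [
        "fastllm is a high-performance LLM inference library",
        "an unrelated sentence about cooking",
    ],
    "return_text": True,   # echo each candidate text back in the response
}

request = RerankRequest(**payload)   # pydantic validates the fields and fills the defaults
print(request.raw_scores, request.truncate, request.truncation_direction)
# -> True False right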
9 changes: 9 additions & 0 deletions tools/fastllm_pytools/server.py
@@ -10,6 +10,7 @@
from .openai_server.protocal.openai_protocol import *
from .openai_server.fastllm_completion import FastLLmCompletion
from .openai_server.fastllm_embed import FastLLmEmbed
from .openai_server.fastllm_reranker import FastLLmReranker
from .util import make_normal_parser
from .util import make_normal_llm_model

@@ -55,6 +56,13 @@ async def create_embed(request: EmbedRequest,
    embedding = fastllm_embed.embedding_sentence(request, raw_request)
    return JSONResponse(embedding)

@app.post("/v1/rerank")
async def create_rerank(request: RerankRequest,
                        raw_request: Request):
    print(request)
    scores = fastllm_reranker.rerank(request, raw_request)
    return JSONResponse(scores)

def init_logging(log_level = logging.INFO, log_file:str = None):
    logging_format = '%(asctime)s %(process)d %(filename)s[line:%(lineno)d] %(levelname)s: %(message)s'
    root = logging.getLogger()
@@ -74,4 +82,5 @@ def init_logging(log_level = logging.INFO, log_file:str = None):
model.set_verbose(True)
fastllm_completion = FastLLmCompletion(model_name = args.model_name, model = model)
fastllm_embed = FastLLmEmbed(model_name = args.model_name, model = model)
fastllm_reranker = FastLLmReranker(model_name = args.model_name, model = model)
uvicorn.run(app, host = args.host, port = args.port)
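
With the server running, the new endpoint can be exercised from any HTTP client. A sketch using requests (the host, port, and example texts are assumptions; use the --host/--port the server was started with):

import requests

# Assumed local address; match it to the server's --host/--port arguments.
resp = requests.post(
    "http://127.0.0.1:8080/v1/rerank",
    json = {
        "query": "what is fastllm",
        "texts": ["fastllm is an inference library", "a sentence about cooking"],
        "return_text": True,
    },
)

# The endpoint returns a JSON list of {"index", "score", and optionally "text"}
# entries, sorted by descending score.
for item in resp.json():
    print(item["index"], item["score"], item.get("text"))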
