update server
AlongWY committed Jun 8, 2024
1 parent 491a5c1 commit 1ee20dd
Showing 2 changed files with 69 additions and 106 deletions.
1 change: 0 additions & 1 deletion .github/workflows/ltp-publish.yml
@@ -29,7 +29,6 @@ jobs:
  release:
    name: Release
    runs-on: ubuntu-latest
-    if: "startsWith(github.ref, 'refs/tags/')"
    needs: [interface]
    steps:
      - uses: actions/download-artifact@v2
174 changes: 69 additions & 105 deletions python/interface/examples/server.py
@@ -9,70 +9,71 @@
python tools/server.py serve
"""

-import sys
-import json
-import logging
-from typing import List
-
+from typing import List, Union
import torch
+from fastapi import FastAPI
+from pydantic import BaseModel
+from ltp import LTP

-from tornado import ioloop
-from tornado.httpserver import HTTPServer
-from tornado.web import Application, RequestHandler
-from tornado.log import app_log, gen_log, access_log, LogFormatter
-from fire import Fire
-
-from ltp import LTP
+class SRLRole(BaseModel):
+    text: str
+    offset: int
+    length: int
+    type: str
+
+
+class Parent(BaseModel):
+    parent: int
+    relate: str
+
+
+class Word(BaseModel):
+    id: int
+    length: int
+    offset: int
+    text: str
+    pos: str
+    parent: int
+    relation: str
+    roles: List[SRLRole]
+    parents: List[Parent]
+
+
+class NE(BaseModel):
+    text: str
+    offset: int
+    ne: str
+    length: int
+
+
+class Item(BaseModel):
+    text: str
+    nes: List[NE]
+    words: List[Word]
+
+
+app = FastAPI()
+
+ltp = LTP("LTP/tiny")
+
+if torch.cuda.is_available():
+    ltp.to("cuda")

-class LTPHandler(RequestHandler):
-    def set_default_headers(self):
-        self.set_header("Access-Control-Allow-Origin", "*")
-        self.set_header('Access-Control-Allow-Headers', 'Content-Type')
-        self.set_header('Access-Control-Allow-Methods', 'GET, POST, PUT, DELETE, PATCH, OPTIONS')
-        self.set_header('Content-Type', 'application/json;charset=UTF-8')
-
-    def initialize(self, ltp):
-        self.set_default_headers()
-        self.ltp = ltp
-
-    def post(self):
-        try:
-            print(self.request.body.decode('utf-8'))
-            text = json.loads(self.request.body.decode('utf-8'))['text']
-            # print(text)
-            result = self.ltp._predict([text])
-            # print(result)
-            self.finish(result)
-        except Exception as e:
-            self.finish(self.ltp._predict(['服务器遇到错误!'])[0])  # i.e. "The server encountered an error!"
-
-    def options(self):
-        pass
-
-
-class Server(object):
-    def __init__(self, path: str = 'LTP/tiny', batch_size: int = 50, device: str = None):
-        # 2024/6/1 7:9:45 adapt for "ltp==4.2.13"
-        self.ltp = LTP(path)
-        self.batch_size = batch_size
-        # Move the model to the GPU
-        if device is None and torch.cuda.is_available():
-            # ltp.cuda()
-            self.ltp.to("cuda")
-        elif device is not None:
-            self.ltp.to(device)
-
-    def _predict(self, sentences: List[str]):
-        output = self.ltp.pipeline(sentences, tasks=["cws", "pos", "ner", "srl", "dep", "sdp", "sdpg"])

-        # https://github.com/HIT-SCIR/ltp/blob/main/python/interface/docs/quickstart.rst
-        # Note that in dependency parsing, the virtual ROOT node occupies position 0, so word indices start from 1.

+@app.post("/api")
+async def predict(sentences: List[str]) -> List[Item]:
+    output = ltp.pipeline(sentences, tasks=["cws", "pos", "ner", "srl", "dep", "sdp", "sdpg"])
+
+    # https://github.com/HIT-SCIR/ltp/blob/main/python/interface/docs/quickstart.rst
+    # Note that in dependency parsing, the virtual ROOT node occupies position 0, so word indices start from 1.
+    result = []
+    for idx, sentence in enumerate(sentences):
        id = 0
        offset = 0
        words = []
        for word, pos, parent, relation in \
-                zip(output.cws[0], output.pos[0], output.dep[0]['head'], output.dep[0]['label']):
+                zip(output.cws[idx], output.pos[idx], output.dep[idx]['head'], output.dep[idx]['label']):
            # print([id, word, pos, parent, relation])
            words.append({
                'id': id,

@@ -88,74 +89,37 @@ def _predict(self, sentences: List[str]):

            id = id + 1
            offset = offset + len(word)

-        for token_srl in output.srl[0]:
-            for argument in token_srl['arguments']:
+        for token_srl in output.srl[idx]:
+            for (argument, text, start, end) in token_srl['arguments']:
                # print(token_srl['index'], token_srl['predicate'], argument)
-                text = argument[1]
-                start = argument[2]
                offset = words[start]['offset']
                words[token_srl['index']]['roles'].append({
                    'text': text,
                    'offset': offset,
                    'length': len(text),
-                    'type': argument[0]
+                    'type': argument
                })

        start = 0
-        for end, label in \
-                zip(output.sdp[0]['head'], output.sdp[0]['label']):
+        for end, label in zip(output.sdp[idx]['head'], output.sdp[idx]['label']):
            words[start]['parents'].append({'parent': end - 1, 'relate': label})
            start = start + 1

        nes = []
-        for role, text, start, end in output.ner[0]:
+        for role, text, start, end in output.ner[idx]:
            nes.append({
                'text': text,
                'offset': start,
                'ne': role.lower(),
                'length': len(text)
            })

-        result = {
-            'text': sentences[0],
-            'nes': nes,
-            'words': words
-        }
-
-        return result
-
-    def serve(self, port: int = 5000, n_process: int = None):
-        if n_process is None:
-            n_process = 1 if sys.platform == 'win32' else 8
-
-        fmt = LogFormatter(fmt='%(asctime)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', color=True)
-        root_logger = logging.getLogger()
-
-        console_handler = logging.StreamHandler()
-        file_handler = logging.FileHandler('server.log')
-
-        console_handler.setFormatter(fmt)
-        file_handler.setFormatter(fmt)
-
-        root_logger.addHandler(console_handler)
-        root_logger.addHandler(file_handler)
-
-        app_log.setLevel(logging.INFO)
-        gen_log.setLevel(logging.INFO)
-        access_log.setLevel(logging.INFO)
-
-        # app_log.info("Model is loading...")
-        app_log.info("Model Has Been Loaded!")
-
-        app = Application([
-            (r"/.*", LTPHandler, dict(ltp=self))
-        ])
-
-        server = HTTPServer(app)
-        server.bind(port)
-        server.start(n_process)
-        ioloop.IOLoop.instance().start()
-
+        result.append(
+            {
+                'text': sentence,
+                'nes': nes,
+                'words': words
+            }
+        )

-if __name__ == '__main__':
-    Fire(Server)
+    return result
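Side note on the parent - 1 and end - 1 shifts in the diff: as the quickstart comment says, the virtual ROOT occupies position 0 and word indices start from 1, so subtracting one re-bases heads to 0-based word indices and maps the root word to the sentinel parent -1. A minimal sketch with made-up head values:

heads = [2, 0, 2]                 # assumed dep heads for a 3-word sentence; 0 marks the root word
parents = [h - 1 for h in heads]  # the re-basing applied in the loops above
assert parents == [1, -1, 1]      # the root word now points at the virtual parent -1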

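The Fire-based "python tools/server.py serve" entry point mentioned in the docstring belongs to the removed Tornado version; the new FastAPI app needs an ASGI runner instead. A hypothetical launcher, not part of this commit (assumes uvicorn is installed and the example is importable as server):

import uvicorn

from server import app  # assumed import path for the example module

if __name__ == "__main__":
    # Host and port are illustrative choices, not taken from the commit.
    uvicorn.run(app, host="0.0.0.0", port=8000)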
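Once served, the new /api route accepts a JSON array of sentences and returns one Item per sentence. A stdlib-only client sketch (address and sample sentence are assumptions):

import json
from urllib.request import Request, urlopen

req = Request(
    "http://127.0.0.1:8000/api",  # assumed local address of the server above
    data=json.dumps(["他叫汤姆去拿外衣。"]).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urlopen(req) as resp:
    for item in json.loads(resp.read().decode("utf-8")):
        print(item["text"], [w["text"] for w in item["words"]])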