diff --git a/README.md b/README.md
new file mode 100644
index 0000000..34ec0fa
--- /dev/null
+++ b/README.md
@@ -0,0 +1,5 @@
+API that returns whether a text input is a question or statement.
+
+Uses best practices:
+- https://huggingface.co/docs/transformers/pipeline_webserver#using-pipelines-for-a-webserver
+- https://fastapi.tiangolo.com/advanced/events/#async-context-manager
\ No newline at end of file
diff --git a/app/main.py b/app/main.py
index ea8042a..04731be 100644
--- a/app/main.py
+++ b/app/main.py
@@ -1,11 +1,23 @@
-from transformers import pipeline
+import asyncio
+from contextlib import asynccontextmanager
+
+from fastapi import FastAPI, Request
 from pydantic import BaseModel
-from fastapi import FastAPI
+from transformers import pipeline
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    q = asyncio.Queue()
+    app.model_queue = q
+    task = asyncio.create_task(server_loop(q))
+    yield
+    task.cancel()
 
-app = FastAPI()
+
+app = FastAPI(lifespan=lifespan)
 
 model = "shahrukhx01/question-vs-statement-classifier"
-pipe = pipeline("text-classification", model=model)
 custom_labels = {"LABEL_0": "STATEMENT", "LABEL_1": "QUESTION"}
 
 
@@ -14,8 +26,19 @@ class Payload(BaseModel):
 
 
 @app.post("/test")
-async def test(payload: Payload):
-    result = pipe(payload.text)[0]
+async def test(payload: Payload, request: Request):
+    response_q = asyncio.Queue()
+    await request.app.model_queue.put((payload.text, response_q))
+    output = await response_q.get()
+    result = output[0]
     # Customize the label
     result["label"] = custom_labels.get(result["label"], result["label"])
     return result
+
+
+async def server_loop(q):
+    pipe = pipeline("text-classification", model=model)
+    while True:
+        (string, response_q) = await q.get()
+        out = pipe(string)
+        await response_q.put(out)