-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
61 lines (47 loc) · 1.83 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from flask import Flask, request, jsonify
from flask_cors import CORS
from haystack.nodes import DensePassageRetriever, FARMReader
from haystack.pipelines import ExtractiveQAPipeline
from haystack.document_stores import FAISSDocumentStore
# Flask application with CORS enabled on it.
app = Flask(__name__)
cors = CORS(app)

# One saved FAISS document store per programme key handled by the /query route.
pgr_document_store = FAISSDocumentStore.load("pgr_test")
pgt_document_store = FAISSDocumentStore.load("PGT")
ug_document_store = FAISSDocumentStore.load("UG")

# The retriever starts out bound to the UG store; the /query handler reassigns
# its `document_store` attribute on each request.
retriever = DensePassageRetriever(
    document_store=ug_document_store,
    query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
    passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
)

# Extractive reader model used to pull answer spans out of retrieved passages.
reader = FARMReader(
    model_name_or_path="deepset/roberta-base-squad2",
    context_window_size=500,
)

# Retriever -> reader pipeline shared by every request.
pipe = ExtractiveQAPipeline(reader=reader, retriever=retriever)
@app.route('/query', methods=['POST'])
def qa():
    """Answer a question using the document store of the requested programme.

    Expects a JSON request body with two keys:
      * "programme": one of "pgr", "pgt" or "ug"; selects the document store.
      * "query": the question passed to the QA pipeline.

    Returns:
        A JSON object with keys "a1", "a2", ... (at most four), in the order
        the pipeline returned its answers. Each value is
        {"answer": ..., "context": ...}. Fewer keys are present when the
        pipeline yields fewer than four answers.

        A 400 response carrying an "error" message is returned when the body
        is not a JSON object, a required key is missing, or "programme" is
        not one of the known values.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    missing = [key for key in ("programme", "query") if key not in payload]
    if missing:
        message = "Missing required field(s): " + ", ".join(missing)
        return jsonify({"error": message}), 400

    stores = {
        "pgr": pgr_document_store,
        "pgt": pgt_document_store,
        "ug": ug_document_store,
    }
    programme = payload["programme"]
    # Reject unknown values explicitly; otherwise the query would silently run
    # against whichever store the previous request left on the retriever.
    if not isinstance(programme, str) or programme not in stores:
        message = "Unknown programme; expected one of: " + ", ".join(stores)
        return jsonify({"error": message}), 400

    # NOTE(review): this reassigns an attribute on the module-level retriever,
    # so concurrent requests for different programmes could interfere with
    # each other - confirm the deployment serves one request at a time.
    retriever.document_store = stores[programme]

    prediction = pipe.run(
        query=payload["query"],
        params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 4}},
    )

    # The pipeline may return fewer than four answers, so build the response
    # from what is actually there instead of indexing fixed positions.
    top_answers = prediction["answers"][:4]
    answers = {
        f"a{rank}": {"answer": found.answer, "context": found.context}
        for rank, found in enumerate(top_answers, start=1)
    }
    return jsonify(answers)