Skip to content

Commit

Permalink
add separate rag_database to be used with RAGChatbot
Browse files Browse the repository at this point in the history
  • Loading branch information
jojortz committed Sep 14, 2023
1 parent f5a4599 commit 2e16378
Show file tree
Hide file tree
Showing 8 changed files with 341 additions and 9 deletions.
6 changes: 3 additions & 3 deletions example/flex/flex_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def main(**kargs):
)

# retrieval, chatbot, and dashboard pykoi components
retriever = pykoi.RetrievalQA(retrieval_model=retrieval_model, vector_db=vector_db)
chatbot = pykoi.Chatbot(None, feedback="vote", is_retrieval=True)
dashboard = pykoi.Dashboard(pykoi.QuestionAnswerDatabase())
retriever = pykoi.RetrievalQA(retrieval_model=retrieval_model, vector_db=vector_db, feedback="rag")
chatbot = pykoi.Chatbot(None, feedback="rag", is_retrieval=True)
dashboard = pykoi.Dashboard(pykoi.RAGDatabase(), feedback="rag")

############################################################
# Starting the application and retrieval qa as a component #
Expand Down
1 change: 1 addition & 0 deletions pykoi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from pykoi.application import Application
from pykoi.chat.db.qa_database import QuestionAnswerDatabase
from pykoi.chat.db.rag_database import RAGDatabase
from pykoi.chat.db.ranking_database import RankingDatabase
from pykoi.chat.llm.abs_llm import AbsLlm
from pykoi.chat.llm.model_factory import ModelFactory
Expand Down
87 changes: 83 additions & 4 deletions pykoi/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pykoi.interactives.chatbot import Chatbot
from pykoi.telemetry.telemetry import Telemetry
from pykoi.telemetry.events import AppStartEvent, AppStopEvent
from pykoi.chat.db.constants import QA_LIST_SEPARATOR


oauth_scheme = HTTPBasic()
Expand All @@ -27,6 +28,17 @@ class UpdateQATable(BaseModel):
id: int
vote_status: str

class UpdateRAGTable(BaseModel):
id: int
vote_status: str

class UpdateQATableAnswer(BaseModel):
id: int
new_answer: str

class UpdateRAGTableAnswer(BaseModel):
id: int
new_answer: str

class RankingTableUpdate(BaseModel):
question: str
Expand Down Expand Up @@ -207,13 +219,28 @@ async def update_qa_table(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()),
):
try:
print("updating QA vote")
component["component"].database.update_vote_status(
request_body.id, request_body.vote_status
)
return {"log": "Table updated", "status": "200"}
except Exception as ex:
return {"log": f"Table update failed: {ex}", "status": "500"}

@app.post("/chat/qa_table/update_answer")
async def update_qa_table_response(
request_body: UpdateQATableAnswer,
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()),
):
try:
component["component"].database.update_answer(
request_body.id, request_body.new_answer
)
print("/chat/qa_table/update_answer", request_body.id, request_body.new_answer)
return {"log": "Table response updated", "new_answer": request_body.new_answer, "status": "200"}
except Exception as ex:
return {"log": f"Table update failed: {ex}", "status": "500"}

@app.get("/chat/qa_table/close")
async def close_qa_table(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
Expand Down Expand Up @@ -269,6 +296,51 @@ async def retrieve_ranking_table(
except Exception as ex:
return {"log": f"Table retrieval failed: {ex}", "status": "500"}

@app.post("/chat/rag_table/update")
async def update_rag_table(
request_body: UpdateRAGTable,
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()),
):
try:
print("updating RAG vote")
component["component"].database.update_vote_status(
request_body.id, request_body.vote_status
)
return {"log": "Table updated", "status": "200"}
except Exception as ex:
return {"log": f"Table update failed: {ex}", "status": "500"}

@app.post("/chat/rag_table/update_answer")
async def update_rag_table_response(
request_body: UpdateRAGTableAnswer,
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()),
):
try:
component["component"].database.update_answer(
request_body.id, request_body.new_answer
)
print("/chat/rag_table/update_answer", request_body.id, request_body.new_answer)
return {"log": "Table response updated", "new_answer": request_body.new_answer, "status": "200"}
except Exception as ex:
return {"log": f"Table update failed: {ex}", "status": "500"}

@app.get("/chat/rag_table/retrieve")
async def retrieve_rag_table(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
try:
rows = component["component"].database.retrieve_all_question_answers()
modified_rows = []
for row in rows:
row_list = list(row) # Convert the tuple to a list
row_list[5] = row_list[5].split(QA_LIST_SEPARATOR)
row_list[6] = row_list[6].split(QA_LIST_SEPARATOR)
row_list[7] = row_list[7].split(QA_LIST_SEPARATOR)
modified_rows.append(row_list) # Append the modified list to the new list
return {"rows": modified_rows, "log": "RAG Table retrieved", "status": "200"}
except Exception as ex:
return {"log": f"Table retrieval failed: {ex}", "status": "500"}

def create_feedback_route(self, app: FastAPI, component: Dict[str, Any]):
"""
Create feedback routes for the application.
Expand All @@ -284,7 +356,14 @@ async def retrieve_qa_table(
):
try:
rows = component["component"].database.retrieve_all_question_answers()
return {"rows": rows, "log": "Table retrieved", "status": "200"}
modified_rows = []
for row in rows:
row_list = list(row) # Convert the tuple to a list
row_list[5] = row_list[5].split(QA_LIST_SEPARATOR)
row_list[6] = row_list[6].split(QA_LIST_SEPARATOR)
row_list[7] = row_list[7].split(QA_LIST_SEPARATOR)
modified_rows.append(row_list) # Append the modified list to the new list
return {"rows": modified_rows, "log": "Table retrieved", "status": "200"}
except Exception as ex:
return {"log": f"Table retrieval failed: {ex}", "status": "500"}

Expand Down Expand Up @@ -487,9 +566,6 @@ async def inference(
print("[/retrieval]: model inference.....", request_body.prompt)
component["component"].retrieval_model.re_init(request_body.file_names)
output = component["component"].retrieval_model.run_with_return_source_documents({"query": request_body.prompt})
id = component["component"].database.insert_question_answer(
request_body.prompt, output["result"]
)
print('output', output, output["result"])
if output["source_documents"] == []:
source = ["N/A"]
Expand All @@ -500,6 +576,9 @@ async def inference(
for source_document in output["source_documents"]:
source.append(source_document.metadata.get('file_name', 'No file name found'))
source_content.append(source_document.page_content)
id = component["component"].database.insert_question_answer(
request_body.prompt, output["result"], request_body.file_names, source, source_content
)
return {
"id": id,
"log": "Inference complete",
Expand Down
3 changes: 3 additions & 0 deletions pykoi/chat/db/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@
RANKING_CSV_HEADER_UP_RANKING_ANSWER,
RANKING_CSV_HEADER_LOW_RANKING_ANSWER,
)

# list separator
QA_LIST_SEPARATOR = "||"
Loading

0 comments on commit 2e16378

Please sign in to comment.