Skip to content

Commit

Permalink
Merge pull request #64 from CambioML/flex
Browse files Browse the repository at this point in the history
Refactoring RAG Components
  • Loading branch information
Cambio ML authored Sep 20, 2023
2 parents 471f0f8 + 2742109 commit 76dde5c
Show file tree
Hide file tree
Showing 28 changed files with 1,414 additions and 99 deletions.
7 changes: 4 additions & 3 deletions example/flex/flex_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pykoi.retrieval import RetrievalFactory
from pykoi.retrieval import VectorDbFactory
from pykoi.component import Chatbot, Dashboard, RetrievalQA
from pykoi.chat import RAGDatabase


load_dotenv()
Expand All @@ -32,9 +33,9 @@ def main(**kargs):
)

# retrieval, chatbot, and dashboard pykoi components
retriever = RetrievalQA(retrieval_model=retrieval_model, vector_db=vector_db)
chatbot = Chatbot(None, feedback="vote", is_retrieval=True)
dashboard = Dashboard(QuestionAnswerDatabase())
retriever = RetrievalQA(retrieval_model=retrieval_model, vector_db=vector_db, feedback="rag")
chatbot = Chatbot(None, feedback="rag", is_retrieval=True)
dashboard = Dashboard(RAGDatabase(), feedback="rag")

############################################################
# Starting the application and retrieval qa as a component #
Expand Down
1 change: 0 additions & 1 deletion pykoi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from pykoi.application import Application


__version__ = "0.0.6"
98 changes: 90 additions & 8 deletions pykoi/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pykoi.interactives.chatbot import Chatbot
from pykoi.telemetry.telemetry import Telemetry
from pykoi.telemetry.events import AppStartEvent, AppStopEvent
from pykoi.chat.db.constants import QA_LIST_SEPARATOR


oauth_scheme = HTTPBasic()
Expand All @@ -26,6 +27,17 @@ class UpdateQATable(BaseModel):
id: int
vote_status: str

class UpdateRAGTable(BaseModel):
id: int
vote_status: str

class UpdateQATableAnswer(BaseModel):
id: int
new_answer: str

class UpdateRAGTableAnswer(BaseModel):
id: int
new_answer: str

class RankingTableUpdate(BaseModel):
question: str
Expand Down Expand Up @@ -209,13 +221,28 @@ async def update_qa_table(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()),
):
try:
print("updating QA vote")
component["component"].database.update_vote_status(
request_body.id, request_body.vote_status
)
return {"log": "Table updated", "status": "200"}
except Exception as ex:
return {"log": f"Table update failed: {ex}", "status": "500"}

@app.post("/chat/qa_table/update_answer")
async def update_qa_table_response(
request_body: UpdateQATableAnswer,
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()),
):
try:
component["component"].database.update_answer(
request_body.id, request_body.new_answer
)
print("/chat/qa_table/update_answer", request_body.id, request_body.new_answer)
return {"log": "Table response updated", "new_answer": request_body.new_answer, "status": "200"}
except Exception as ex:
return {"log": f"Table update failed: {ex}", "status": "500"}

@app.get("/chat/qa_table/close")
async def close_qa_table(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
Expand Down Expand Up @@ -271,6 +298,51 @@ async def retrieve_ranking_table(
except Exception as ex:
return {"log": f"Table retrieval failed: {ex}", "status": "500"}

@app.post("/chat/rag_table/update")
async def update_rag_table(
request_body: UpdateRAGTable,
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()),
):
try:
print("updating RAG vote")
component["component"].database.update_vote_status(
request_body.id, request_body.vote_status
)
return {"log": "Table updated", "status": "200"}
except Exception as ex:
return {"log": f"Table update failed: {ex}", "status": "500"}

@app.post("/chat/rag_table/update_answer")
async def update_rag_table_response(
request_body: UpdateRAGTableAnswer,
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()),
):
try:
component["component"].database.update_answer(
request_body.id, request_body.new_answer
)
print("/chat/rag_table/update_answer", request_body.id, request_body.new_answer)
return {"log": "Table response updated", "new_answer": request_body.new_answer, "status": "200"}
except Exception as ex:
return {"log": f"Table update failed: {ex}", "status": "500"}

@app.get("/chat/rag_table/retrieve")
async def retrieve_rag_table(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
try:
rows = component["component"].database.retrieve_all_question_answers()
modified_rows = []
for row in rows:
row_list = list(row) # Convert the tuple to a list
row_list[5] = row_list[5].split(QA_LIST_SEPARATOR)
row_list[6] = row_list[6].split(QA_LIST_SEPARATOR)
row_list[7] = row_list[7].split(QA_LIST_SEPARATOR)
modified_rows.append(row_list) # Append the modified list to the new list
return {"rows": modified_rows, "log": "RAG Table retrieved", "status": "200"}
except Exception as ex:
return {"log": f"Table retrieval failed: {ex}", "status": "500"}

def create_feedback_route(self, app: FastAPI, component: Dict[str, Any]):
"""
Create feedback routes for the application.
Expand All @@ -286,7 +358,14 @@ async def retrieve_qa_table(
):
try:
rows = component["component"].database.retrieve_all_question_answers()
return {"rows": rows, "log": "Table retrieved", "status": "200"}
modified_rows = []
for row in rows:
row_list = list(row) # Convert the tuple to a list
row_list[5] = row_list[5].split(QA_LIST_SEPARATOR)
row_list[6] = row_list[6].split(QA_LIST_SEPARATOR)
row_list[7] = row_list[7].split(QA_LIST_SEPARATOR)
modified_rows.append(row_list) # Append the modified list to the new list
return {"rows": modified_rows, "log": "Table retrieved", "status": "200"}
except Exception as ex:
return {"log": f"Table retrieval failed: {ex}", "status": "500"}

Expand Down Expand Up @@ -489,16 +568,19 @@ async def inference(
print("[/retrieval]: model inference.....", request_body.prompt)
component["component"].retrieval_model.re_init(request_body.file_names)
output = component["component"].retrieval_model.run_with_return_source_documents({"query": request_body.prompt})
id = component["component"].database.insert_question_answer(
request_body.prompt, output["result"]
)
print('output', output, output["result"])
if output["source_documents"] == []:
source = "N/A"
source_content = "N/A"
source = ["N/A"]
source_content = ["N/A"]
else:
source = output["source_documents"][0].metadata.get('file_name', 'No file name found')
source_content = "1. " + output["source_documents"][0].page_content + "\n2. " + output["source_documents"][1].page_content
source = []
source_content = []
for source_document in output["source_documents"]:
source.append(source_document.metadata.get('file_name', 'No file name found'))
source_content.append(source_document.page_content)
id = component["component"].database.insert_question_answer(
request_body.prompt, output["result"], request_body.file_names, source, source_content
)
return {
"id": id,
"log": "Inference complete",
Expand Down
1 change: 1 addition & 0 deletions pykoi/chat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from pykoi.chat.llm.model_factory import ModelFactory
from pykoi.chat.db.qa_database import QuestionAnswerDatabase
from pykoi.chat.db.ranking_database import RankingDatabase
from pykoi.chat.db.rag_database import RAGDatabase
3 changes: 3 additions & 0 deletions pykoi/chat/db/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@
RANKING_CSV_HEADER_UP_RANKING_ANSWER,
RANKING_CSV_HEADER_LOW_RANKING_ANSWER,
)

# list separator
QA_LIST_SEPARATOR = "||"
Loading

0 comments on commit 76dde5c

Please sign in to comment.