From bed82f01efdb99c1e1c2076d9a498a4d21575c31 Mon Sep 17 00:00:00 2001 From: Jojo Ortiz Date: Wed, 27 Sep 2023 10:46:31 -0700 Subject: [PATCH] add endpoints to save_to_csv for rag_table and qa_table --- pykoi/application.py | 39 +++++++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/pykoi/application.py b/pykoi/application.py index 2b27e8f..ac021c8 100644 --- a/pykoi/application.py +++ b/pykoi/application.py @@ -17,7 +17,7 @@ from pykoi.interactives.chatbot import Chatbot from pykoi.telemetry.telemetry import Telemetry from pykoi.telemetry.events import AppStartEvent, AppStopEvent -from pykoi.chat.db.constants import QA_LIST_SEPARATOR +from pykoi.chat.db.constants import RAG_LIST_SEPARATOR oauth_scheme = HTTPBasic() @@ -44,26 +44,27 @@ class RankingTableUpdate(BaseModel): up_ranking_answer: str low_ranking_answer: str - class InferenceRankingTable(BaseModel): n: Optional[int] = 2 - class ModelAnswer(BaseModel): model: str qid: int rank: int answer: str - class ComparatorInsertRequest(BaseModel): data: List[ModelAnswer] - class RetrievalNewMessage(BaseModel): prompt: str file_names: List[str] +class QATableToCSV(BaseModel): + file_name: str + +class RAGTableToCSV(BaseModel): + file_name: str class UserInDB: def __init__(self, username: str, hashed_password: str): @@ -243,6 +244,17 @@ async def update_qa_table_response( except Exception as ex: return {"log": f"Table update failed: {ex}", "status": "500"} + @app.post("/chat/qa_table/save_to_csv") + async def save_qa_table_to_csv( + request_body: QATableToCSV, + user: Union[None, UserInDB] = Depends(self.get_auth_dependency()), + ): + try: + component["component"].database.save_to_csv(request_body.file_name) + return {"log": f"Saved to {request_body.file_name}.csv", "status": "200"} + except Exception as ex: + return {"log": f"Save to CSV failed: {ex}", "status": "500"} + @app.get("/chat/qa_table/close") async def close_qa_table( user: Union[None, UserInDB] = Depends(self.get_auth_dependency()) @@ -335,14 +347,25 @@ async def retrieve_rag_table( modified_rows = [] for row in rows: row_list = list(row) # Convert the tuple to a list - row_list[5] = row_list[5].split(QA_LIST_SEPARATOR) - row_list[6] = row_list[6].split(QA_LIST_SEPARATOR) - row_list[7] = row_list[7].split(QA_LIST_SEPARATOR) + row_list[5] = row_list[5].split(RAG_LIST_SEPARATOR) + row_list[6] = row_list[6].split(RAG_LIST_SEPARATOR) + row_list[7] = row_list[7].split(RAG_LIST_SEPARATOR) modified_rows.append(row_list) # Append the modified list to the new list return {"rows": modified_rows, "log": "RAG Table retrieved", "status": "200"} except Exception as ex: return {"log": f"Table retrieval failed: {ex}", "status": "500"} + @app.post("/chat/rag_table/save_to_csv") + async def save_rag_table_to_csv( + request_body: RAGTableToCSV, + user: Union[None, UserInDB] = Depends(self.get_auth_dependency()), + ): + try: + component["component"].database.save_to_csv(request_body.file_name) + return {"log": f"Saved to {request_body.file_name}.csv", "status": "200"} + except Exception as ex: + return {"log": f"Save to CSV failed: {ex}", "status": "500"} + def create_feedback_route(self, app: FastAPI, component: Dict[str, Any]): """ Create feedback routes for the application.