diff --git a/pykoi/application.py b/pykoi/application.py index bc69c51..37201ac 100644 --- a/pykoi/application.py +++ b/pykoi/application.py @@ -2,9 +2,7 @@ import asyncio import os import re -import socket import subprocess -import threading import time from datetime import datetime @@ -15,9 +13,7 @@ from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles from pydantic import BaseModel -from pyngrok import ngrok from starlette.middleware.cors import CORSMiddleware -from pykoi.interactives.chatbot import Chatbot from pykoi.telemetry.telemetry import Telemetry from pykoi.telemetry.events import AppStartEvent, AppStopEvent from pykoi.chat.db.constants import RAG_LIST_SEPARATOR @@ -30,48 +26,60 @@ class UpdateQATable(BaseModel): id: int vote_status: str + class UpdateRAGTable(BaseModel): id: int vote_status: str + class UpdateQATableAnswer(BaseModel): id: int new_answer: str + class UpdateRAGTableAnswer(BaseModel): id: int new_answer: str + class RankingTableUpdate(BaseModel): question: str up_ranking_answer: str low_ranking_answer: str + class InferenceRankingTable(BaseModel): n: Optional[int] = 2 + class ModelAnswer(BaseModel): model: str qid: int rank: int answer: str + class ComparatorInsertRequest(BaseModel): data: List[ModelAnswer] + class RetrievalNewMessage(BaseModel): prompt: str file_names: List[str] + class QATableToCSV(BaseModel): file_name: str + class RAGTableToCSV(BaseModel): file_name: str + class ComparatorTableToCSV(BaseModel): file_name: str + class UserInDB: def __init__(self, username: str, hashed_password: str): self.username = username @@ -97,7 +105,7 @@ def __init__( Initialize the Application. Args: - share (bool, optional): If True, the application will be shared via ngrok. Defaults to False. + share (bool, optional): If True, the application will be shared via localhost.run. Defaults to False. debug (bool, optional): If True, the application will run in debug mode. Defaults to False. username (str, optional): The username for authentication. Defaults to None. password (str, optional): The password for authentication. Defaults to None. @@ -245,8 +253,16 @@ async def update_qa_table_response( component["component"].database.update_answer( request_body.id, request_body.new_answer ) - print("/chat/qa_table/update_answer", request_body.id, request_body.new_answer) - return {"log": "Table response updated", "new_answer": request_body.new_answer, "status": "200"} + print( + "/chat/qa_table/update_answer", + request_body.id, + request_body.new_answer, + ) + return { + "log": "Table response updated", + "new_answer": request_body.new_answer, + "status": "200", + } except Exception as ex: return {"log": f"Table update failed: {ex}", "status": "500"} @@ -257,7 +273,10 @@ async def save_qa_table_to_csv( ): try: component["component"].database.save_to_csv(request_body.file_name) - return {"log": f"Saved to {request_body.file_name}.csv", "status": "200"} + return { + "log": f"Saved to {request_body.file_name}.csv", + "status": "200", + } except Exception as ex: return {"log": f"Save to CSV failed: {ex}", "status": "500"} @@ -339,8 +358,16 @@ async def update_rag_table_response( component["component"].database.update_answer( request_body.id, request_body.new_answer ) - print("/chat/rag_table/update_answer", request_body.id, request_body.new_answer) - return {"log": "Table response updated", "new_answer": request_body.new_answer, "status": "200"} + print( + "/chat/rag_table/update_answer", + request_body.id, + request_body.new_answer, + ) + return { + "log": "Table response updated", + "new_answer": request_body.new_answer, + "status": "200", + } except Exception as ex: return {"log": f"Table update failed: {ex}", "status": "500"} @@ -356,8 +383,14 @@ async def retrieve_rag_table( row_list[5] = row_list[5].split(RAG_LIST_SEPARATOR) row_list[6] = row_list[6].split(RAG_LIST_SEPARATOR) row_list[7] = row_list[7].split(RAG_LIST_SEPARATOR) - modified_rows.append(row_list) # Append the modified list to the new list - return {"rows": modified_rows, "log": "RAG Table retrieved", "status": "200"} + modified_rows.append( + row_list + ) # Append the modified list to the new list + return { + "rows": modified_rows, + "log": "RAG Table retrieved", + "status": "200", + } except Exception as ex: return {"log": f"Table retrieval failed: {ex}", "status": "500"} @@ -368,7 +401,10 @@ async def save_rag_table_to_csv( ): try: component["component"].database.save_to_csv(request_body.file_name) - return {"log": f"Saved to {request_body.file_name}.csv", "status": "200"} + return { + "log": f"Saved to {request_body.file_name}.csv", + "status": "200", + } except Exception as ex: return {"log": f"Save to CSV failed: {ex}", "status": "500"} @@ -464,7 +500,9 @@ async def retrieve_comparator( user: Union[None, UserInDB] = Depends(self.get_auth_dependency()) ): try: - rows = component["component"].comparator_db.retrieve_all_question_answers() + rows = component[ + "component" + ].comparator_db.retrieve_all_question_answers() data = [] for row in rows: a_id, model_name, qid, question, answer, rank, _ = row @@ -502,11 +540,13 @@ async def save_comparator_table_to_csv( try: print("Saving Comparator to CSV", request_body.file_name) component["component"].comparator_db.save_to_csv(request_body.file_name) - return {"log": f"Saved to {request_body.file_name}.csv", "status": "200"} + return { + "log": f"Saved to {request_body.file_name}.csv", + "status": "200", + } except Exception as ex: return {"log": f"Save to CSV failed: {ex}", "status": "500"} - def create_qa_retrieval_route(self, app: FastAPI, component: Dict[str, Any]): """ Create QA retrieval routes for the application. @@ -604,8 +644,12 @@ async def inference( try: print("[/retrieval]: model inference.....", request_body.prompt) component["component"].retrieval_model.re_init(request_body.file_names) - output = component["component"].retrieval_model.run_with_return_source_documents({"query": request_body.prompt}) - print('output', output, output["result"]) + output = component[ + "component" + ].retrieval_model.run_with_return_source_documents( + {"query": request_body.prompt} + ) + print("output", output, output["result"]) if output["source_documents"] == []: source = ["N/A"] source_content = ["N/A"] @@ -613,10 +657,18 @@ async def inference( source = [] source_content = [] for source_document in output["source_documents"]: - source.append(source_document.metadata.get('file_name', 'No file name found')) + source.append( + source_document.metadata.get( + "file_name", "No file name found" + ) + ) source_content.append(source_document.page_content) id = component["component"].database.insert_question_answer( - request_body.prompt, output["result"], request_body.file_names, source, source_content + request_body.prompt, + output["result"], + request_body.file_names, + source, + source_content, ) return { "id": id, @@ -783,7 +835,6 @@ async def read_item( # debug mode should be set to False in production because # it will start two processes when debug mode is enabled. - # Set the ngrok tunnel if share is True start_event = AppStartEvent( start_time=time.time(), date_time=datetime.utcfromtimestamp(time.time()) ) @@ -833,115 +884,3 @@ async def read_item( duration=time.time() - start_event.start_time, ) ) - - def display(self): - """ - Run the application. - """ - print("hey2") - import nest_asyncio - - nest_asyncio.apply() - app = FastAPI() - - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - @app.post("/token") - def login(credentials: HTTPBasicCredentials = Depends(oauth_scheme)): - user = self.authenticate_user( - self._fake_users_db, credentials.username, credentials.password - ) - if not user: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect username or password", - headers={"WWW-Authenticate": "Basic"}, - ) - return {"message": "Logged in successfully"} - - @app.get("/components") - async def get_components( - user: Union[None, UserInDB] = Depends(self.get_auth_dependency()) - ): - return JSONResponse( - [ - { - "id": component["id"], - "svelte_component": component["svelte_component"], - "props": component["props"], - } - for component in self.components - ] - ) - - def create_data_route(id: str, data_source: Any): - """ - Create data route for the application. - - Args: - id (str): The id of the data source. - data_source (Any): The data source. - """ - - @app.get(f"/data/{id}") - async def get_data( - user: Union[None, UserInDB] = Depends(self.get_auth_dependency()) - ): - data = data_source.fetch_func() - return JSONResponse(data) - - for id, data_source in self.data_sources.items(): - create_data_route(id, data_source) - - for component in self.components: - if component["svelte_component"] == "Chatbot": - self.create_chatbot_route(app, component) - if component["svelte_component"] == "Feedback": - self.create_feedback_route(app, component) - if component["svelte_component"] == "Compare": - self.create_chatbot_comparator_route(app, component) - - app.mount( - "/", - StaticFiles( - directory=os.path.join( - os.path.dirname(os.path.realpath(__file__)), "frontend/dist" - ), - html=True, - ), - name="static", - ) - - @app.get("/{path:path}") - async def read_item( - path: str, user: Union[None, UserInDB] = Depends(self.get_auth_dependency()) - ): - return {"path": path} - - # debug mode should be set to False in production because - # it will start two processes when debug mode is enabled. - - # Set the ngrok tunnel if share is True - if self._share: - public_url = ngrok.connect(self._port) - print("Public URL:", public_url) - import uvicorn - - uvicorn.run(app, host=self._host, port=self._port) - print("Stopping server...") - ngrok.disconnect(public_url) - else: - import uvicorn - - def run_uvicorn(): - uvicorn.run(app, host=self._host, port=self._port) - - t = threading.Thread(target=run_uvicorn) - t.start() - return Chatbot()(port=self._host)