Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove unused dependencies #76

Merged
merged 1 commit into from
Oct 6, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 72 additions & 133 deletions pykoi/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
import asyncio
import os
import re
import socket
import subprocess
import threading
import time

from datetime import datetime
Expand All @@ -15,9 +13,7 @@
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from pyngrok import ngrok
from starlette.middleware.cors import CORSMiddleware
from pykoi.interactives.chatbot import Chatbot
from pykoi.telemetry.telemetry import Telemetry
from pykoi.telemetry.events import AppStartEvent, AppStopEvent
from pykoi.chat.db.constants import RAG_LIST_SEPARATOR
Expand All @@ -30,48 +26,60 @@ class UpdateQATable(BaseModel):
id: int
vote_status: str


class UpdateRAGTable(BaseModel):
id: int
vote_status: str


class UpdateQATableAnswer(BaseModel):
id: int
new_answer: str


class UpdateRAGTableAnswer(BaseModel):
id: int
new_answer: str


class RankingTableUpdate(BaseModel):
question: str
up_ranking_answer: str
low_ranking_answer: str


class InferenceRankingTable(BaseModel):
n: Optional[int] = 2


class ModelAnswer(BaseModel):
model: str
qid: int
rank: int
answer: str


class ComparatorInsertRequest(BaseModel):
data: List[ModelAnswer]


class RetrievalNewMessage(BaseModel):
prompt: str
file_names: List[str]


class QATableToCSV(BaseModel):
file_name: str


class RAGTableToCSV(BaseModel):
file_name: str


class ComparatorTableToCSV(BaseModel):
file_name: str


class UserInDB:
def __init__(self, username: str, hashed_password: str):
self.username = username
Expand All @@ -97,7 +105,7 @@ def __init__(
Initialize the Application.
Args:
share (bool, optional): If True, the application will be shared via ngrok. Defaults to False.
share (bool, optional): If True, the application will be shared via localhost.run. Defaults to False.
debug (bool, optional): If True, the application will run in debug mode. Defaults to False.
username (str, optional): The username for authentication. Defaults to None.
password (str, optional): The password for authentication. Defaults to None.
Expand Down Expand Up @@ -245,8 +253,16 @@ async def update_qa_table_response(
component["component"].database.update_answer(
request_body.id, request_body.new_answer
)
print("/chat/qa_table/update_answer", request_body.id, request_body.new_answer)
return {"log": "Table response updated", "new_answer": request_body.new_answer, "status": "200"}
print(
"/chat/qa_table/update_answer",
request_body.id,
request_body.new_answer,
)
return {
"log": "Table response updated",
"new_answer": request_body.new_answer,
"status": "200",
}
except Exception as ex:
return {"log": f"Table update failed: {ex}", "status": "500"}

Expand All @@ -257,7 +273,10 @@ async def save_qa_table_to_csv(
):
try:
component["component"].database.save_to_csv(request_body.file_name)
return {"log": f"Saved to {request_body.file_name}.csv", "status": "200"}
return {
"log": f"Saved to {request_body.file_name}.csv",
"status": "200",
}
except Exception as ex:
return {"log": f"Save to CSV failed: {ex}", "status": "500"}

Expand Down Expand Up @@ -339,8 +358,16 @@ async def update_rag_table_response(
component["component"].database.update_answer(
request_body.id, request_body.new_answer
)
print("/chat/rag_table/update_answer", request_body.id, request_body.new_answer)
return {"log": "Table response updated", "new_answer": request_body.new_answer, "status": "200"}
print(
"/chat/rag_table/update_answer",
request_body.id,
request_body.new_answer,
)
return {
"log": "Table response updated",
"new_answer": request_body.new_answer,
"status": "200",
}
except Exception as ex:
return {"log": f"Table update failed: {ex}", "status": "500"}

Expand All @@ -356,8 +383,14 @@ async def retrieve_rag_table(
row_list[5] = row_list[5].split(RAG_LIST_SEPARATOR)
row_list[6] = row_list[6].split(RAG_LIST_SEPARATOR)
row_list[7] = row_list[7].split(RAG_LIST_SEPARATOR)
modified_rows.append(row_list) # Append the modified list to the new list
return {"rows": modified_rows, "log": "RAG Table retrieved", "status": "200"}
modified_rows.append(
row_list
) # Append the modified list to the new list
return {
"rows": modified_rows,
"log": "RAG Table retrieved",
"status": "200",
}
except Exception as ex:
return {"log": f"Table retrieval failed: {ex}", "status": "500"}

Expand All @@ -368,7 +401,10 @@ async def save_rag_table_to_csv(
):
try:
component["component"].database.save_to_csv(request_body.file_name)
return {"log": f"Saved to {request_body.file_name}.csv", "status": "200"}
return {
"log": f"Saved to {request_body.file_name}.csv",
"status": "200",
}
except Exception as ex:
return {"log": f"Save to CSV failed: {ex}", "status": "500"}

Expand Down Expand Up @@ -464,7 +500,9 @@ async def retrieve_comparator(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
try:
rows = component["component"].comparator_db.retrieve_all_question_answers()
rows = component[
"component"
].comparator_db.retrieve_all_question_answers()
data = []
for row in rows:
a_id, model_name, qid, question, answer, rank, _ = row
Expand Down Expand Up @@ -502,11 +540,13 @@ async def save_comparator_table_to_csv(
try:
print("Saving Comparator to CSV", request_body.file_name)
component["component"].comparator_db.save_to_csv(request_body.file_name)
return {"log": f"Saved to {request_body.file_name}.csv", "status": "200"}
return {
"log": f"Saved to {request_body.file_name}.csv",
"status": "200",
}
except Exception as ex:
return {"log": f"Save to CSV failed: {ex}", "status": "500"}


def create_qa_retrieval_route(self, app: FastAPI, component: Dict[str, Any]):
"""
Create QA retrieval routes for the application.
Expand Down Expand Up @@ -604,19 +644,31 @@ async def inference(
try:
print("[/retrieval]: model inference.....", request_body.prompt)
component["component"].retrieval_model.re_init(request_body.file_names)
output = component["component"].retrieval_model.run_with_return_source_documents({"query": request_body.prompt})
print('output', output, output["result"])
output = component[
"component"
].retrieval_model.run_with_return_source_documents(
{"query": request_body.prompt}
)
print("output", output, output["result"])
if output["source_documents"] == []:
source = ["N/A"]
source_content = ["N/A"]
else:
source = []
source_content = []
for source_document in output["source_documents"]:
source.append(source_document.metadata.get('file_name', 'No file name found'))
source.append(
source_document.metadata.get(
"file_name", "No file name found"
)
)
source_content.append(source_document.page_content)
id = component["component"].database.insert_question_answer(
request_body.prompt, output["result"], request_body.file_names, source, source_content
request_body.prompt,
output["result"],
request_body.file_names,
source,
source_content,
)
return {
"id": id,
Expand Down Expand Up @@ -783,7 +835,6 @@ async def read_item(
# debug mode should be set to False in production because
# it will start two processes when debug mode is enabled.

# Set the ngrok tunnel if share is True
start_event = AppStartEvent(
start_time=time.time(), date_time=datetime.utcfromtimestamp(time.time())
)
Expand Down Expand Up @@ -833,115 +884,3 @@ async def read_item(
duration=time.time() - start_event.start_time,
)
)

def display(self):
"""
Run the application.
"""
print("hey2")
import nest_asyncio

nest_asyncio.apply()
app = FastAPI()

app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

@app.post("/token")
def login(credentials: HTTPBasicCredentials = Depends(oauth_scheme)):
user = self.authenticate_user(
self._fake_users_db, credentials.username, credentials.password
)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Basic"},
)
return {"message": "Logged in successfully"}

@app.get("/components")
async def get_components(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
return JSONResponse(
[
{
"id": component["id"],
"svelte_component": component["svelte_component"],
"props": component["props"],
}
for component in self.components
]
)

def create_data_route(id: str, data_source: Any):
"""
Create data route for the application.
Args:
id (str): The id of the data source.
data_source (Any): The data source.
"""

@app.get(f"/data/{id}")
async def get_data(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
data = data_source.fetch_func()
return JSONResponse(data)

for id, data_source in self.data_sources.items():
create_data_route(id, data_source)

for component in self.components:
if component["svelte_component"] == "Chatbot":
self.create_chatbot_route(app, component)
if component["svelte_component"] == "Feedback":
self.create_feedback_route(app, component)
if component["svelte_component"] == "Compare":
self.create_chatbot_comparator_route(app, component)

app.mount(
"/",
StaticFiles(
directory=os.path.join(
os.path.dirname(os.path.realpath(__file__)), "frontend/dist"
),
html=True,
),
name="static",
)

@app.get("/{path:path}")
async def read_item(
path: str, user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
return {"path": path}

# debug mode should be set to False in production because
# it will start two processes when debug mode is enabled.

# Set the ngrok tunnel if share is True
if self._share:
public_url = ngrok.connect(self._port)
print("Public URL:", public_url)
import uvicorn

uvicorn.run(app, host=self._host, port=self._port)
print("Stopping server...")
ngrok.disconnect(public_url)
else:
import uvicorn

def run_uvicorn():
uvicorn.run(app, host=self._host, port=self._port)

t = threading.Thread(target=run_uvicorn)
t.start()
return Chatbot()(port=self._host)