-
Notifications
You must be signed in to change notification settings - Fork 44
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add CSV Download frontend and backend #70
Changes from all commits
3f43606
921f90c
588b44a
0bf1e8a
cb265d6
bed82f0
4a99c78
bd5c4e6
281830d
41f7265
9f046af
c940692
c98fe79
6a12dff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,26 +2,34 @@ | |
Demo for the chatbot application using multiple OpenAI models. | ||
|
||
- Prerequisites: | ||
To run this jupyter notebook, you need a `pykoi` environment with the `rag` option. | ||
You can follow [the installation guide](https://github.com/CambioML/pykoi/tree/install#option-1-rag-cpu) | ||
to set up the environment. | ||
To run this jupyter notebook, you need a `pykoi` environment with the `rag` option. | ||
You can follow [the installation guide](https://github.com/CambioML/pykoi/tree/install#option-1-rag-cpu) | ||
to set up the environment. | ||
- Run the demo: | ||
1. Enter your OpenAI API key in the `api_key` below. | ||
2. On terminal and `~/pykoi` directory, run | ||
1. Enter your OpenAI API key a .env file in the `~/pykoi` directory with the name OPEN_API_KEY, e.g. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: same |
||
``` | ||
OPENAI_API_KEY=your_api_key | ||
``` | ||
2. On terminal and `~/pykoi` directory, run | ||
``` | ||
python -m example.comparator.demo_model_comparator_cpu_openai | ||
``` | ||
""" | ||
|
||
import os | ||
|
||
from dotenv import load_dotenv | ||
|
||
from pykoi import Application | ||
from pykoi.chat import ModelFactory | ||
from pykoi.component import Compare | ||
|
||
|
||
########################################################## | ||
# Creating an OpenAI model (requires an OpenAI API key) # | ||
########################################################## | ||
# enter openai api key here | ||
api_key = "" | ||
load_dotenv() | ||
api_key = os.getenv("OPENAI_API_KEY") | ||
|
||
# Creating an OpenAI model | ||
openai_model_1 = ModelFactory.create_model( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,7 +20,7 @@ | |
from pykoi.interactives.chatbot import Chatbot | ||
from pykoi.telemetry.telemetry import Telemetry | ||
from pykoi.telemetry.events import AppStartEvent, AppStopEvent | ||
from pykoi.chat.db.constants import QA_LIST_SEPARATOR | ||
from pykoi.chat.db.constants import RAG_LIST_SEPARATOR | ||
|
||
|
||
oauth_scheme = HTTPBasic() | ||
|
@@ -47,26 +47,30 @@ class RankingTableUpdate(BaseModel): | |
up_ranking_answer: str | ||
low_ranking_answer: str | ||
|
||
|
||
class InferenceRankingTable(BaseModel): | ||
n: Optional[int] = 2 | ||
|
||
|
||
class ModelAnswer(BaseModel): | ||
model: str | ||
qid: int | ||
rank: int | ||
answer: str | ||
|
||
|
||
class ComparatorInsertRequest(BaseModel): | ||
data: List[ModelAnswer] | ||
|
||
|
||
class RetrievalNewMessage(BaseModel): | ||
prompt: str | ||
file_names: List[str] | ||
|
||
class QATableToCSV(BaseModel): | ||
file_name: str | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: make sure you run |
||
class RAGTableToCSV(BaseModel): | ||
file_name: str | ||
|
||
class ComparatorTableToCSV(BaseModel): | ||
file_name: str | ||
|
||
class UserInDB: | ||
def __init__(self, username: str, hashed_password: str): | ||
|
@@ -246,6 +250,17 @@ async def update_qa_table_response( | |
except Exception as ex: | ||
return {"log": f"Table update failed: {ex}", "status": "500"} | ||
|
||
@app.post("/chat/qa_table/save_to_csv") | ||
async def save_qa_table_to_csv( | ||
request_body: QATableToCSV, | ||
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()), | ||
): | ||
try: | ||
component["component"].database.save_to_csv(request_body.file_name) | ||
return {"log": f"Saved to {request_body.file_name}.csv", "status": "200"} | ||
except Exception as ex: | ||
return {"log": f"Save to CSV failed: {ex}", "status": "500"} | ||
|
||
@app.get("/chat/qa_table/close") | ||
async def close_qa_table( | ||
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()) | ||
|
@@ -338,14 +353,25 @@ async def retrieve_rag_table( | |
modified_rows = [] | ||
for row in rows: | ||
row_list = list(row) # Convert the tuple to a list | ||
row_list[5] = row_list[5].split(QA_LIST_SEPARATOR) | ||
row_list[6] = row_list[6].split(QA_LIST_SEPARATOR) | ||
row_list[7] = row_list[7].split(QA_LIST_SEPARATOR) | ||
row_list[5] = row_list[5].split(RAG_LIST_SEPARATOR) | ||
row_list[6] = row_list[6].split(RAG_LIST_SEPARATOR) | ||
row_list[7] = row_list[7].split(RAG_LIST_SEPARATOR) | ||
modified_rows.append(row_list) # Append the modified list to the new list | ||
return {"rows": modified_rows, "log": "RAG Table retrieved", "status": "200"} | ||
except Exception as ex: | ||
return {"log": f"Table retrieval failed: {ex}", "status": "500"} | ||
|
||
@app.post("/chat/rag_table/save_to_csv") | ||
async def save_rag_table_to_csv( | ||
request_body: RAGTableToCSV, | ||
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()), | ||
): | ||
try: | ||
component["component"].database.save_to_csv(request_body.file_name) | ||
return {"log": f"Saved to {request_body.file_name}.csv", "status": "200"} | ||
except Exception as ex: | ||
return {"log": f"Save to CSV failed: {ex}", "status": "500"} | ||
|
||
def create_feedback_route(self, app: FastAPI, component: Dict[str, Any]): | ||
""" | ||
Create feedback routes for the application. | ||
|
@@ -438,17 +464,19 @@ async def retrieve_comparator( | |
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()) | ||
): | ||
try: | ||
rows = component["component"].comparator_db.retrieve_all() | ||
rows = component["component"].comparator_db.retrieve_all_question_answers() | ||
data = [] | ||
for row in rows: | ||
_, model_name, qid, rank, answer, _ = row | ||
a_id, model_name, qid, question, answer, rank, _ = row | ||
|
||
data.append( | ||
{ | ||
"id": a_id, | ||
"model": model_name, | ||
"qid": qid, | ||
"rank": rank, | ||
"question": question, | ||
"answer": answer, | ||
"rank": rank, | ||
} | ||
) | ||
return {"data": data, "log": "Table retrieved", "status": "200"} | ||
|
@@ -466,6 +494,19 @@ async def close_comparator( | |
except Exception as ex: | ||
return {"log": f"Table close failed: {ex}", "status": "500"} | ||
|
||
@app.post("/chat/comparator/db/save_to_csv") | ||
async def save_comparator_table_to_csv( | ||
request_body: ComparatorTableToCSV, | ||
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()), | ||
): | ||
try: | ||
print("Saving Comparator to CSV", request_body.file_name) | ||
component["component"].comparator_db.save_to_csv(request_body.file_name) | ||
return {"log": f"Saved to {request_body.file_name}.csv", "status": "200"} | ||
except Exception as ex: | ||
return {"log": f"Save to CSV failed: {ex}", "status": "500"} | ||
|
||
|
||
def create_qa_retrieval_route(self, app: FastAPI, component: Dict[str, Any]): | ||
""" | ||
Create QA retrieval routes for the application. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,5 +26,45 @@ | |
RANKING_CSV_HEADER_LOW_RANKING_ANSWER, | ||
) | ||
|
||
# RAG table | ||
RAG_CSV_HEADER_ID = "ID" | ||
RAG_CSV_HEADER_QUESTION = "Question" | ||
RAG_CSV_HEADER_ANSWER = "Answer" | ||
RAG_CSV_HEADER_EDITED = "Edited Answer" | ||
RAG_CSV_HEADER_VOTE_STATUS = "Vote Status" | ||
RAG_CSV_HEADER_RAG_SOURCES = "RAG Sources" | ||
RAG_CSV_HEADER_SOURCE = "Source" | ||
RAG_CSV_HEADER_SOURCE_CONTENT = "Source Content" | ||
RAG_CSV_HEADER_TIMESTAMP = "Timestamp" | ||
RAG_CSV_HEADER = ( | ||
RAG_CSV_HEADER_ID, | ||
RAG_CSV_HEADER_QUESTION, | ||
RAG_CSV_HEADER_ANSWER, | ||
RAG_CSV_HEADER_EDITED, | ||
RAG_CSV_HEADER_VOTE_STATUS, | ||
RAG_CSV_HEADER_RAG_SOURCES, | ||
RAG_CSV_HEADER_SOURCE, | ||
RAG_CSV_HEADER_SOURCE_CONTENT, | ||
RAG_CSV_HEADER_TIMESTAMP, | ||
) | ||
|
||
# list separator | ||
QA_LIST_SEPARATOR = "||" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: run |
||
RAG_LIST_SEPARATOR = "||" | ||
|
||
# Comparator table | ||
COMPARATOR_CSV_HEADER_ID = "ID" | ||
COMPARATOR_CSV_HEADER_MODEL = "Model" | ||
COMPARATOR_CSV_HEADER_QID = "Question ID" | ||
COMPARATOR_CSV_HEADER_QUESTION = "Question" | ||
COMPARATOR_CSV_HEADER_ANSWER = "Answer" | ||
COMPARATOR_CSV_HEADER_RANK = "Rank" | ||
COMPARATOR_CSV_HEADER_TIMESTAMP = "Timestamp" | ||
COMPARATOR_CSV_HEADER = ( | ||
COMPARATOR_CSV_HEADER_ID, | ||
COMPARATOR_CSV_HEADER_MODEL, | ||
COMPARATOR_CSV_HEADER_QID, | ||
COMPARATOR_CSV_HEADER_QUESTION, | ||
COMPARATOR_CSV_HEADER_ANSWER, | ||
COMPARATOR_CSV_HEADER_RANK, | ||
COMPARATOR_CSV_HEADER_TIMESTAMP, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
~/pykoi
directory might not exist depending on where user clone the repo. You can say pykoi root directory.