Skip to content

Commit

Permalink
add save_to_csv functionality for Comparator database
Browse files Browse the repository at this point in the history
  • Loading branch information
jojortz committed Oct 3, 2023
1 parent 9f046af commit c940692
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 83 deletions.
24 changes: 21 additions & 3 deletions pykoi/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ class QATableToCSV(BaseModel):
class RAGTableToCSV(BaseModel):
file_name: str

class ComparatorTableToCSV(BaseModel):
file_name: str

class UserInDB:
def __init__(self, username: str, hashed_password: str):
self.username = username
Expand Down Expand Up @@ -458,17 +461,19 @@ async def retrieve_comparator(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
try:
rows = component["component"].comparator_db.retrieve_all()
rows = component["component"].comparator_db.retrieve_all_question_answers()
data = []
for row in rows:
_, model_name, qid, rank, answer, _ = row
a_id, model_name, qid, question, answer, rank, _ = row

data.append(
{
"id": a_id,
"model": model_name,
"qid": qid,
"rank": rank,
"question": question,
"answer": answer,
"rank": rank,
}
)
return {"data": data, "log": "Table retrieved", "status": "200"}
Expand All @@ -486,6 +491,19 @@ async def close_comparator(
except Exception as ex:
return {"log": f"Table close failed: {ex}", "status": "500"}

@app.post("/chat/comparator/db/save_to_csv")
async def save_comparator_table_to_csv(
request_body: ComparatorTableToCSV,
user: Union[None, UserInDB] = Depends(self.get_auth_dependency()),
):
try:
print("Saving Comparator to CSV", request_body.file_name)
component["component"].comparator_db.save_to_csv(request_body.file_name)
return {"log": f"Saved to {request_body.file_name}.csv", "status": "200"}
except Exception as ex:
return {"log": f"Save to CSV failed: {ex}", "status": "500"}


def create_qa_retrieval_route(self, app: FastAPI, component: Dict[str, Any]):
"""
Create QA retrieval routes for the application.
Expand Down
43 changes: 43 additions & 0 deletions pykoi/chat/db/comparator_database.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Comparator Database"""
import csv
import datetime
import os

from typing import List, Tuple

from pykoi.chat.db.abs_database import AbsDatabase
from pykoi.chat.db.constants import COMPARATOR_CSV_HEADER


class ComparatorQuestionDatabase(AbsDatabase):
Expand Down Expand Up @@ -200,6 +202,25 @@ def retrieve_all(self) -> List[Tuple]:
rows = cursor.fetchall()
return rows

def retrieve_all_question_answers(self):
"""
Retrieves all question-answer pairs from the database.
Returns:
rows: rows of data of the question-answer pairs.
"""
query = """
SELECT comparator.id, comparator.model, comparator.qid, comparator_question.question, comparator.answer, comparator.rank, comparator.timestamp
FROM comparator
JOIN comparator_question
ON comparator.qid = comparator_question.id;
"""
with self._lock:
cursor = self.get_cursor()
cursor.execute(query)
rows = cursor.fetchall()
return rows

def print_table(self, rows: List[Tuple]) -> None:
"""
Prints the comparator table.
Expand All @@ -217,3 +238,25 @@ def print_table(self, rows: List[Tuple]) -> None:
f"Answer: {row[4]}, "
f"Timestamp: {row[5]}"
)

def save_to_csv(self, csv_file_name="comparator_table"):
"""
This method saves the contents of the RAG table into a CSV file.
Args:
csv_file_name (str, optional): The name of the CSV file to which the data will be written.
Defaults to "comparator_table".
The CSV file will have the following columns: TODO. Each row in the
CSV file corresponds to a row in the question_answer table.
This method first retrieves all question-answer pairs from the database by calling the
retrieve_all method. It then writes this data to the CSV file.
"""

my_sql_data = self.retrieve_all_question_answers()

with open(csv_file_name + ".csv", "w", newline="") as file:
writer = csv.writer(file)
writer.writerow(COMPARATOR_CSV_HEADER)
writer.writerows(my_sql_data)
20 changes: 19 additions & 1 deletion pykoi/chat/db/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,22 @@
)

# list separator
RAG_LIST_SEPARATOR = "||"
RAG_LIST_SEPARATOR = "||"

# Comparator table
COMPARATOR_CSV_HEADER_ID = "ID"
COMPARATOR_CSV_HEADER_MODEL = "Model"
COMPARATOR_CSV_HEADER_QID = "Question ID"
COMPARATOR_CSV_HEADER_QUESTION = "Question"
COMPARATOR_CSV_HEADER_ANSWER = "Answer"
COMPARATOR_CSV_HEADER_RANK = "Rank"
COMPARATOR_CSV_HEADER_TIMESTAMP = "Timestamp"
COMPARATOR_CSV_HEADER = (
COMPARATOR_CSV_HEADER_ID,
COMPARATOR_CSV_HEADER_MODEL,
COMPARATOR_CSV_HEADER_QID,
COMPARATOR_CSV_HEADER_QUESTION,
COMPARATOR_CSV_HEADER_ANSWER,
COMPARATOR_CSV_HEADER_RANK,
COMPARATOR_CSV_HEADER_TIMESTAMP,
)
64 changes: 64 additions & 0 deletions pykoi/frontend/dist/assets/index-b2c98be3.js

Large diffs are not rendered by default.

68 changes: 0 additions & 68 deletions pykoi/frontend/dist/assets/index-c65ac7af.js

This file was deleted.

2 changes: 1 addition & 1 deletion pykoi/frontend/dist/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Vite + Svelte</title>
<script type="module" crossorigin src="/assets/index-c65ac7af.js"></script>
<script type="module" crossorigin src="/assets/index-b2c98be3.js"></script>
<link rel="stylesheet" href="/assets/index-51debcf5.css">
</head>
<body>
Expand Down
44 changes: 34 additions & 10 deletions pykoi/frontend/src/lib/Chatbots/ComparisonChat.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import { compareChatLog } from "../../store";
import Sortable from "sortablejs";
import { select } from "d3-selection";
import { mode } from "d3-array";
import DownloadModal from "./Components/DownloadModal.svelte";
export let numModels = 1;
export let models = [0];
Expand All @@ -12,6 +14,7 @@
let chatLoading = false;
$: gridTemplate = "1fr ".repeat(numModels).trim();
let answerOrder = [];
let showModal = false;
onMount(async () => {
// Give the DOM some time to render
Expand All @@ -25,20 +28,33 @@
});
answerOrder = sortable.toArray();
}
// retrieveDBData();
retrieveDBData();
});
async function retrieveDBData() {
const response = await fetch("/chat/comparator/db/retrieve");
const data = await response.json();
// const dbRows = data["data"];
// const formattedRows = dbRows.map((row) => ({
// id: row[0],
// question: row[1],
// up_ranking_answer: row[2],
// low_ranking_answer: row[3],
// }));
// $compareChatLog = [...dbRows];
console.log(data);
const dbRows = data["data"];
let formattedRows = {};
let modelSet = new Set();
for (const row of dbRows) {
modelSet.add(row["model"]);
if (formattedRows[row["qid"]]) {
formattedRows[row["qid"]][row["model"]] = row["answer"];
} else {
formattedRows[row["qid"]] = {};
formattedRows[row["qid"]]["qid"] = row["qid"];
formattedRows[row["qid"]]["question"] = row["question"];
formattedRows[row["qid"]][row["model"]] = row["answer"];
}
console.log(formattedRows);
}
models = Array.from(modelSet);
numModels = models.length;
console.log(Object.values(formattedRows))
$compareChatLog = [...Object.values(formattedRows)];
}
const askModel = async (event) => {
Expand All @@ -54,6 +70,8 @@
}
$compareChatLog = [...$compareChatLog, currentEntry];
console.log('compare chat log', compareChatLog)
const response = await fetch(`/chat/comparator/${mymessage}`, {
method: "POST",
headers: {
Expand Down Expand Up @@ -190,8 +208,14 @@
payload.push(entry);
updateComparisonDB(payload);
}
function handleDownloadClick () {
showModal = true;
}
</script>

<DownloadModal bind:showModal table="comparator/db"/>

<div class="ranked-feedback-container">
<div class="instructions">
<h5 class="underline bold">Q & A Comparison Instructions</h5>
Expand All @@ -202,7 +226,7 @@
rank for each via the corresponding dropdown.
</p>
<br />
<button>Download Data</button>
<button on:click={handleDownloadClick}>Download Data</button>
</div>
<div class="ranked-chat">
<section class="chatbox">
Expand Down

0 comments on commit c940692

Please sign in to comment.