Skip to content

Commit

Permalink
update save_to_csv function for qa_database and rag_database
Browse files Browse the repository at this point in the history
  • Loading branch information
jojortz committed Sep 27, 2023
1 parent 0bf1e8a commit cb265d6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
6 changes: 3 additions & 3 deletions pykoi/chat/db/qa_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,13 @@ def print_table(self, rows):
f"Answer: {row[2]}, Vote Status: {row[3]}, Timestamp: {row[4]}"
)

def save_to_csv(self, csv_file_name="question_answer_votes.csv"):
def save_to_csv(self, csv_file_name="question_answer_votes"):
"""
This method saves the contents of the question_answer table into a CSV file.
Args:
csv_file_name (str, optional): The name of the CSV file to which the data will be written.
Defaults to "question_answer_votes.csv".
Defaults to "question_answer_votes".
The CSV file will have the following columns: ID, Question, Answer, Vote Status. Each row in the
CSV file corresponds to a row in the question_answer table.
Expand All @@ -196,7 +196,7 @@ def save_to_csv(self, csv_file_name="question_answer_votes.csv"):
"""
my_sql_data = self.retrieve_all_question_answers()

with open(csv_file_name, "w", newline="") as file:
with open(csv_file_name + ".csv", "w", newline="") as file:
writer = csv.writer(file)
writer.writerow(QA_CSV_HEADER)
writer.writerows(my_sql_data)
20 changes: 10 additions & 10 deletions pykoi/chat/db/rag_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pandas as pd

from pykoi.chat.db.constants import QA_CSV_HEADER, QA_LIST_SEPARATOR
from pykoi.chat.db.constants import RAG_CSV_HEADER, RAG_LIST_SEPARATOR


class RAGDatabase:
Expand Down Expand Up @@ -93,9 +93,9 @@ def insert_question_answer(self, question: str, answer: str, rag_sources: list,
INSERT INTO rag_question_answer (question, answer, edited_answer, rag_sources, source, source_content, vote_status, timestamp)
VALUES (?, ?, '', ?, ?, ?, 'n/a', ?);
"""
rag_sources = QA_LIST_SEPARATOR.join(rag_sources)
source = QA_LIST_SEPARATOR.join(source)
source_content = QA_LIST_SEPARATOR.join(source_content)
rag_sources = RAG_LIST_SEPARATOR.join(rag_sources)
source = RAG_LIST_SEPARATOR.join(source)
source_content = RAG_LIST_SEPARATOR.join(source_content)
print("rag insert question answer", rag_sources)

with self._lock:
Expand Down Expand Up @@ -222,23 +222,23 @@ def print_table(self, rows):
f"Answer: {row[2]}, Vote Status: {row[3]}, Timestamp: {row[4]}"
)

def save_to_csv(self, csv_file_name="question_answer_votes.csv"):
def save_to_csv(self, csv_file_name="rag_table"):
"""
This method saves the contents of the question_answer table into a CSV file.
This method saves the contents of the RAG table into a CSV file.
Args:
csv_file_name (str, optional): The name of the CSV file to which the data will be written.
Defaults to "question_answer_votes.csv".
Defaults to "rag_table".
The CSV file will have the following columns: ID, Question, Answer, Vote Status. Each row in the
The CSV file will have the following columns: ID, Question, Answer, Edited Answer, Vote Status, RAG Sources, Source, Source Content, and Timestamp. Each row in the
CSV file corresponds to a row in the question_answer table.
This method first retrieves all question-answer pairs from the database by calling the
retrieve_all_question_answers method. It then writes this data to the CSV file.
"""
my_sql_data = self.retrieve_all_question_answers()

with open(csv_file_name, "w", newline="") as file:
with open(csv_file_name + ".csv", "w", newline="") as file:
writer = csv.writer(file)
writer.writerow(QA_CSV_HEADER)
writer.writerow(RAG_CSV_HEADER)
writer.writerows(my_sql_data)

0 comments on commit cb265d6

Please sign in to comment.