-
Notifications
You must be signed in to change notification settings - Fork 44
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
608 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
"""Test the QuestionAnswerDatabase class""" | ||
import datetime | ||
import os | ||
import sqlite3 | ||
import unittest | ||
|
||
from pykoi.chat.db.qa_database import QuestionAnswerDatabase | ||
|
||
# Define a temporary database file for testing | ||
TEST_DB_FILE = "test_qd.db" | ||
|
||
|
||
class TestQuestionAnswerDatabase(unittest.TestCase): | ||
""" | ||
Test the QuestionAnswerDatabase class. | ||
""" | ||
|
||
def setUp(self): | ||
# Create a temporary database for testing | ||
self.qadb = QuestionAnswerDatabase(db_file=TEST_DB_FILE, debug=False) | ||
|
||
def tearDown(self): | ||
# Remove the temporary database and close connections after each test | ||
self.qadb.close_connection() | ||
os.remove(TEST_DB_FILE) | ||
|
||
def test_create_table(self): | ||
""" | ||
Test whether the table is created correctly. | ||
""" | ||
# Test whether the table is created correctly | ||
conn = sqlite3.connect(TEST_DB_FILE) | ||
cursor = conn.cursor() | ||
|
||
cursor.execute( | ||
"SELECT name FROM sqlite_master WHERE type='table' AND name='question_answer'" | ||
) | ||
table_exists = cursor.fetchone() | ||
|
||
self.assertTrue(table_exists) | ||
|
||
# Clean up | ||
cursor.close() | ||
conn.close() | ||
|
||
def test_insert_and_retrieve_question_answer(self): | ||
""" | ||
Test inserting and retrieving a question-answer pair | ||
""" | ||
question = "What is the meaning of life?" | ||
answer = "42" | ||
|
||
# Insert data and get the ID | ||
qa_id = self.qadb.insert_question_answer(question, answer) | ||
|
||
# Retrieve the data | ||
rows = self.qadb.retrieve_all_question_answers() | ||
|
||
# Check if the data was inserted correctly | ||
self.assertEqual(len(rows), 1) | ||
self.assertEqual(rows[0][0], qa_id) | ||
self.assertEqual(rows[0][1], question) | ||
self.assertEqual(rows[0][2], answer) | ||
self.assertEqual(rows[0][3], "n/a") # Default vote status | ||
|
||
def test_update_vote_status(self): | ||
""" | ||
Test updating the vote status of a question-answer pair. | ||
""" | ||
question = "What is the meaning of life?" | ||
answer = "42" | ||
|
||
# Insert data and get the ID | ||
qa_id = self.qadb.insert_question_answer(question, answer) | ||
|
||
# Update the vote status | ||
new_vote_status = "up" | ||
self.qadb.update_vote_status(qa_id, new_vote_status) | ||
|
||
# Retrieve the data | ||
rows = self.qadb.retrieve_all_question_answers() | ||
|
||
# Check if the vote status was updated correctly | ||
self.assertEqual(len(rows), 1) | ||
self.assertEqual(rows[0][0], qa_id) | ||
self.assertEqual(rows[0][3], new_vote_status) | ||
|
||
def test_save_to_csv(self): | ||
""" | ||
Test saving data to a CSV file | ||
""" | ||
question1 = "What is the meaning of life?" | ||
answer1 = "42" | ||
question2 = "What is the best programming language?" | ||
answer2 = "Python" | ||
|
||
# Insert data | ||
timestamp = datetime.datetime.now() | ||
self.qadb.insert_question_answer(question1, answer1) | ||
self.qadb.insert_question_answer(question2, answer2) | ||
|
||
# Save to CSV | ||
self.qadb.save_to_csv("test_csv_file.csv") | ||
|
||
# Check if the CSV file was created and contains the correct data | ||
self.assertTrue(os.path.exists("test_csv_file.csv")) | ||
|
||
with open("test_csv_file.csv", "r") as file: | ||
lines = file.readlines() | ||
|
||
# Verify the CSV file content | ||
timestamp_trim = 10 # Trim 10 characters from the timestamp | ||
self.assertEqual(len(lines), 3) # Header + 2 rows | ||
self.assertEqual(lines[0].strip(), "ID,Question,Answer,Vote Status,Timestamp") | ||
self.assertEqual( | ||
lines[1].strip()[:-timestamp_trim], | ||
f"1,{question1},{answer1},n/a,{timestamp}"[:-timestamp_trim], | ||
) # Default vote status | ||
self.assertEqual( | ||
lines[2].strip()[:-timestamp_trim], | ||
f"2,{question2},{answer2},n/a,{timestamp}"[:-timestamp_trim], | ||
) # Default vote status | ||
|
||
# Clean up | ||
os.remove("test_csv_file.csv") | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
""" | ||
Test the RankingDatabase class | ||
""" | ||
import os | ||
import sqlite3 | ||
import unittest | ||
|
||
from pykoi.chat.db.ranking_database import RankingDatabase | ||
|
||
# Define a temporary database file for testing | ||
TEST_DB_FILE = "test_ranking.db" | ||
|
||
|
||
class TestRankingDatabase(unittest.TestCase): | ||
""" | ||
Test the RankingDatabase class | ||
""" | ||
|
||
def setUp(self): | ||
# Create a temporary database for testing | ||
self.ranking_db = RankingDatabase(db_file=TEST_DB_FILE, debug=False) | ||
|
||
def tearDown(self): | ||
# Remove the temporary database and close connections after each test | ||
self.ranking_db.close_connection() | ||
os.remove(TEST_DB_FILE) | ||
|
||
def test_create_table(self): | ||
""" | ||
Test whether the table is created correctly. | ||
""" | ||
conn = sqlite3.connect(TEST_DB_FILE) | ||
cursor = conn.cursor() | ||
|
||
cursor.execute( | ||
"SELECT name FROM sqlite_master WHERE type='table' AND name='ranking'" | ||
) | ||
table_exists = cursor.fetchone() | ||
|
||
self.assertTrue(table_exists) | ||
|
||
# Clean up | ||
cursor.close() | ||
conn.close() | ||
|
||
def test_insert_and_retrieve_ranking(self): | ||
""" | ||
Test inserting and retrieving a ranking entry | ||
""" | ||
question = "Which fruit is your favorite?" | ||
up_ranking_answer = "Apple" | ||
low_ranking_answer = "Banana" | ||
|
||
# Insert data and get the ID | ||
ranking_id = self.ranking_db.insert_ranking( | ||
question, up_ranking_answer, low_ranking_answer | ||
) | ||
|
||
# Retrieve the data | ||
rows = self.ranking_db.retrieve_all_question_answers() | ||
|
||
# Check if the data was inserted correctly | ||
self.assertEqual(len(rows), 1) | ||
self.assertEqual(rows[0][0], ranking_id) | ||
self.assertEqual(rows[0][1], question) | ||
self.assertEqual(rows[0][2], up_ranking_answer) | ||
self.assertEqual(rows[0][3], low_ranking_answer) | ||
|
||
def test_save_to_csv(self): | ||
""" | ||
Test saving data to a CSV file | ||
""" | ||
question1 = "Which fruit is your favorite?" | ||
up_ranking_answer1 = "Apple" | ||
low_ranking_answer1 = "Banana" | ||
question2 = "Which country would you like to visit?" | ||
up_ranking_answer2 = "Japan" | ||
low_ranking_answer2 = "Italy" | ||
|
||
# Insert data | ||
self.ranking_db.insert_ranking( | ||
question1, up_ranking_answer1, low_ranking_answer1 | ||
) | ||
self.ranking_db.insert_ranking( | ||
question2, up_ranking_answer2, low_ranking_answer2 | ||
) | ||
|
||
# Save to CSV | ||
self.ranking_db.save_to_csv("test_csv_file.csv") | ||
|
||
# Check if the CSV file was created and contains the correct data | ||
self.assertTrue(os.path.exists("test_csv_file.csv")) | ||
|
||
with open("test_csv_file.csv", "r") as file: | ||
lines = file.readlines() | ||
|
||
# Verify the CSV file content | ||
self.assertEqual(len(lines), 3) # Header + 2 rows | ||
self.assertEqual( | ||
lines[0].strip(), "ID,Question,Up Ranking Answer,Low Ranking Answer" | ||
) | ||
self.assertEqual( | ||
lines[1].strip(), | ||
f"1,{question1},{up_ranking_answer1},{low_ranking_answer1}", | ||
) | ||
self.assertEqual( | ||
lines[2].strip(), | ||
f"2,{question2},{up_ranking_answer2},{low_ranking_answer2}", | ||
) | ||
|
||
# Clean up | ||
os.remove("test_csv_file.csv") | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
""" | ||
Test the AbsLlm class. | ||
""" | ||
import unittest | ||
|
||
from pykoi.chat.llm.abs_llm import AbsLlm | ||
|
||
|
||
class DummyLlm(AbsLlm): | ||
""" | ||
Dummy class for testing the abstract base class AbsLlm. | ||
""" | ||
|
||
def predict(self, message: str): | ||
return f"Q: {message}, A: N/A." | ||
|
||
|
||
class TestAbsLlm(unittest.TestCase): | ||
""" | ||
Test the AbsLlm class. | ||
""" | ||
|
||
def test_predict_abstract_method(self): | ||
""" | ||
Test whether the abstract method `predict` raises NotImplementedError | ||
""" | ||
|
||
test_message = "test" | ||
llm = DummyLlm() | ||
self.assertEqual(llm.predict(test_message), f"Q: {test_message}, A: N/A.") | ||
|
||
def test_docstring(self): | ||
""" | ||
Test whether the class has a docstring | ||
""" | ||
self.assertIsNotNone(AbsLlm.__doc__) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
""" | ||
Test the constants of the LLM module. | ||
""" | ||
import unittest | ||
|
||
from pykoi.chat.llm.constants import ModelSource | ||
|
||
|
||
class TestLlmName(unittest.TestCase): | ||
""" | ||
Test the ModelSource enum. | ||
""" | ||
|
||
def test_enum_values(self): | ||
""" | ||
Test whether the enum values are defined correctly | ||
""" | ||
self.assertEqual(ModelSource.OPENAI.value, "openai") | ||
self.assertEqual(ModelSource.HUGGINGFACE.value, "huggingface") | ||
|
||
def test_enum_attributes(self): | ||
""" | ||
Test whether the enum attributes are defined correctly | ||
""" | ||
self.assertEqual(ModelSource.OPENAI.name, "OPENAI") | ||
self.assertEqual(ModelSource.HUGGINGFACE.name, "HUGGINGFACE") | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import unittest | ||
from unittest.mock import patch, Mock | ||
from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
||
from pykoi.chat.llm.abs_llm import AbsLlm | ||
from pykoi.chat.llm.huggingface import HuggingfaceModel | ||
|
||
|
||
class TestHuggingfaceModel(unittest.TestCase): | ||
@patch.object(AutoModelForCausalLM, "from_pretrained") | ||
@patch.object(AutoTokenizer, "from_pretrained") | ||
def setUp(self, mock_model, mock_tokenizer): | ||
self.model_name = "gpt2" | ||
self.mock_model = mock_model | ||
self.mock_tokenizer = mock_tokenizer | ||
|
||
# Mocking the pretrained model and tokenizer | ||
self.mock_model.return_value = Mock() | ||
self.mock_tokenizer.return_value = Mock() | ||
|
||
self.huggingface_model = HuggingfaceModel( | ||
pretrained_model_name_or_path=self.model_name | ||
) | ||
|
||
def test_name(self): | ||
expected_name = f"{HuggingfaceModel.model_source}_{self.model_name}_100" | ||
self.assertEqual(self.huggingface_model.name, expected_name) | ||
|
||
@patch.object(HuggingfaceModel, "predict") | ||
def test_predict(self, mock_predict): | ||
mock_predict.return_value = ["Hello, how can I assist you today?"] | ||
message = "Hello, chatbot!" | ||
num_of_response = 1 | ||
response = self.huggingface_model.predict(message, num_of_response) | ||
mock_predict.assert_called_once_with(message, num_of_response) | ||
self.assertEqual(response, ["Hello, how can I assist you today?"]) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Oops, something went wrong.