Skip to content

Commit

Permalink
add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ZimpleX committed Sep 19, 2023
1 parent ae5ed07 commit 5da603f
Show file tree
Hide file tree
Showing 14 changed files with 608 additions and 0 deletions.
Empty file added tests/__init__.py
Empty file.
Empty file added tests/chat/__init__.py
Empty file.
Empty file added tests/chat/db/__init__.py
Empty file.
129 changes: 129 additions & 0 deletions tests/chat/db/test_qa_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""Test the QuestionAnswerDatabase class"""
import datetime
import os
import sqlite3
import unittest

from pykoi.chat.db.qa_database import QuestionAnswerDatabase

# Define a temporary database file for testing
TEST_DB_FILE = "test_qd.db"


class TestQuestionAnswerDatabase(unittest.TestCase):
"""
Test the QuestionAnswerDatabase class.
"""

def setUp(self):
# Create a temporary database for testing
self.qadb = QuestionAnswerDatabase(db_file=TEST_DB_FILE, debug=False)

def tearDown(self):
# Remove the temporary database and close connections after each test
self.qadb.close_connection()
os.remove(TEST_DB_FILE)

def test_create_table(self):
"""
Test whether the table is created correctly.
"""
# Test whether the table is created correctly
conn = sqlite3.connect(TEST_DB_FILE)
cursor = conn.cursor()

cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='question_answer'"
)
table_exists = cursor.fetchone()

self.assertTrue(table_exists)

# Clean up
cursor.close()
conn.close()

def test_insert_and_retrieve_question_answer(self):
"""
Test inserting and retrieving a question-answer pair
"""
question = "What is the meaning of life?"
answer = "42"

# Insert data and get the ID
qa_id = self.qadb.insert_question_answer(question, answer)

# Retrieve the data
rows = self.qadb.retrieve_all_question_answers()

# Check if the data was inserted correctly
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0][0], qa_id)
self.assertEqual(rows[0][1], question)
self.assertEqual(rows[0][2], answer)
self.assertEqual(rows[0][3], "n/a") # Default vote status

def test_update_vote_status(self):
"""
Test updating the vote status of a question-answer pair.
"""
question = "What is the meaning of life?"
answer = "42"

# Insert data and get the ID
qa_id = self.qadb.insert_question_answer(question, answer)

# Update the vote status
new_vote_status = "up"
self.qadb.update_vote_status(qa_id, new_vote_status)

# Retrieve the data
rows = self.qadb.retrieve_all_question_answers()

# Check if the vote status was updated correctly
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0][0], qa_id)
self.assertEqual(rows[0][3], new_vote_status)

def test_save_to_csv(self):
"""
Test saving data to a CSV file
"""
question1 = "What is the meaning of life?"
answer1 = "42"
question2 = "What is the best programming language?"
answer2 = "Python"

# Insert data
timestamp = datetime.datetime.now()
self.qadb.insert_question_answer(question1, answer1)
self.qadb.insert_question_answer(question2, answer2)

# Save to CSV
self.qadb.save_to_csv("test_csv_file.csv")

# Check if the CSV file was created and contains the correct data
self.assertTrue(os.path.exists("test_csv_file.csv"))

with open("test_csv_file.csv", "r") as file:
lines = file.readlines()

# Verify the CSV file content
timestamp_trim = 10 # Trim 10 characters from the timestamp
self.assertEqual(len(lines), 3) # Header + 2 rows
self.assertEqual(lines[0].strip(), "ID,Question,Answer,Vote Status,Timestamp")
self.assertEqual(
lines[1].strip()[:-timestamp_trim],
f"1,{question1},{answer1},n/a,{timestamp}"[:-timestamp_trim],
) # Default vote status
self.assertEqual(
lines[2].strip()[:-timestamp_trim],
f"2,{question2},{answer2},n/a,{timestamp}"[:-timestamp_trim],
) # Default vote status

# Clean up
os.remove("test_csv_file.csv")


if __name__ == "__main__":
unittest.main()
116 changes: 116 additions & 0 deletions tests/chat/db/test_ranking_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""
Test the RankingDatabase class
"""
import os
import sqlite3
import unittest

from pykoi.chat.db.ranking_database import RankingDatabase

# Define a temporary database file for testing
TEST_DB_FILE = "test_ranking.db"


class TestRankingDatabase(unittest.TestCase):
"""
Test the RankingDatabase class
"""

def setUp(self):
# Create a temporary database for testing
self.ranking_db = RankingDatabase(db_file=TEST_DB_FILE, debug=False)

def tearDown(self):
# Remove the temporary database and close connections after each test
self.ranking_db.close_connection()
os.remove(TEST_DB_FILE)

def test_create_table(self):
"""
Test whether the table is created correctly.
"""
conn = sqlite3.connect(TEST_DB_FILE)
cursor = conn.cursor()

cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='ranking'"
)
table_exists = cursor.fetchone()

self.assertTrue(table_exists)

# Clean up
cursor.close()
conn.close()

def test_insert_and_retrieve_ranking(self):
"""
Test inserting and retrieving a ranking entry
"""
question = "Which fruit is your favorite?"
up_ranking_answer = "Apple"
low_ranking_answer = "Banana"

# Insert data and get the ID
ranking_id = self.ranking_db.insert_ranking(
question, up_ranking_answer, low_ranking_answer
)

# Retrieve the data
rows = self.ranking_db.retrieve_all_question_answers()

# Check if the data was inserted correctly
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0][0], ranking_id)
self.assertEqual(rows[0][1], question)
self.assertEqual(rows[0][2], up_ranking_answer)
self.assertEqual(rows[0][3], low_ranking_answer)

def test_save_to_csv(self):
"""
Test saving data to a CSV file
"""
question1 = "Which fruit is your favorite?"
up_ranking_answer1 = "Apple"
low_ranking_answer1 = "Banana"
question2 = "Which country would you like to visit?"
up_ranking_answer2 = "Japan"
low_ranking_answer2 = "Italy"

# Insert data
self.ranking_db.insert_ranking(
question1, up_ranking_answer1, low_ranking_answer1
)
self.ranking_db.insert_ranking(
question2, up_ranking_answer2, low_ranking_answer2
)

# Save to CSV
self.ranking_db.save_to_csv("test_csv_file.csv")

# Check if the CSV file was created and contains the correct data
self.assertTrue(os.path.exists("test_csv_file.csv"))

with open("test_csv_file.csv", "r") as file:
lines = file.readlines()

# Verify the CSV file content
self.assertEqual(len(lines), 3) # Header + 2 rows
self.assertEqual(
lines[0].strip(), "ID,Question,Up Ranking Answer,Low Ranking Answer"
)
self.assertEqual(
lines[1].strip(),
f"1,{question1},{up_ranking_answer1},{low_ranking_answer1}",
)
self.assertEqual(
lines[2].strip(),
f"2,{question2},{up_ranking_answer2},{low_ranking_answer2}",
)

# Clean up
os.remove("test_csv_file.csv")


if __name__ == "__main__":
unittest.main()
Empty file added tests/chat/llm/__init__.py
Empty file.
40 changes: 40 additions & 0 deletions tests/chat/llm/test_abs_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""
Test the AbsLlm class.
"""
import unittest

from pykoi.chat.llm.abs_llm import AbsLlm


class DummyLlm(AbsLlm):
"""
Dummy class for testing the abstract base class AbsLlm.
"""

def predict(self, message: str):
return f"Q: {message}, A: N/A."


class TestAbsLlm(unittest.TestCase):
"""
Test the AbsLlm class.
"""

def test_predict_abstract_method(self):
"""
Test whether the abstract method `predict` raises NotImplementedError
"""

test_message = "test"
llm = DummyLlm()
self.assertEqual(llm.predict(test_message), f"Q: {test_message}, A: N/A.")

def test_docstring(self):
"""
Test whether the class has a docstring
"""
self.assertIsNotNone(AbsLlm.__doc__)


if __name__ == "__main__":
unittest.main()
30 changes: 30 additions & 0 deletions tests/chat/llm/test_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""
Test the constants of the LLM module.
"""
import unittest

from pykoi.chat.llm.constants import ModelSource


class TestLlmName(unittest.TestCase):
"""
Test the ModelSource enum.
"""

def test_enum_values(self):
"""
Test whether the enum values are defined correctly
"""
self.assertEqual(ModelSource.OPENAI.value, "openai")
self.assertEqual(ModelSource.HUGGINGFACE.value, "huggingface")

def test_enum_attributes(self):
"""
Test whether the enum attributes are defined correctly
"""
self.assertEqual(ModelSource.OPENAI.name, "OPENAI")
self.assertEqual(ModelSource.HUGGINGFACE.name, "HUGGINGFACE")


if __name__ == "__main__":
unittest.main()
40 changes: 40 additions & 0 deletions tests/chat/llm/test_huggingface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import unittest
from unittest.mock import patch, Mock
from transformers import AutoModelForCausalLM, AutoTokenizer

from pykoi.chat.llm.abs_llm import AbsLlm
from pykoi.chat.llm.huggingface import HuggingfaceModel


class TestHuggingfaceModel(unittest.TestCase):
@patch.object(AutoModelForCausalLM, "from_pretrained")
@patch.object(AutoTokenizer, "from_pretrained")
def setUp(self, mock_model, mock_tokenizer):
self.model_name = "gpt2"
self.mock_model = mock_model
self.mock_tokenizer = mock_tokenizer

# Mocking the pretrained model and tokenizer
self.mock_model.return_value = Mock()
self.mock_tokenizer.return_value = Mock()

self.huggingface_model = HuggingfaceModel(
pretrained_model_name_or_path=self.model_name
)

def test_name(self):
expected_name = f"{HuggingfaceModel.model_source}_{self.model_name}_100"
self.assertEqual(self.huggingface_model.name, expected_name)

@patch.object(HuggingfaceModel, "predict")
def test_predict(self, mock_predict):
mock_predict.return_value = ["Hello, how can I assist you today?"]
message = "Hello, chatbot!"
num_of_response = 1
response = self.huggingface_model.predict(message, num_of_response)
mock_predict.assert_called_once_with(message, num_of_response)
self.assertEqual(response, ["Hello, how can I assist you today?"])


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 5da603f

Please sign in to comment.