From 5da603fdbd710787ca2f0a39b0798575150042cd Mon Sep 17 00:00:00 2001 From: ZimpleX <zhqhku@gmail.com> Date: Mon, 18 Sep 2023 22:40:02 -0700 Subject: [PATCH] add unit tests --- tests/__init__.py | 0 tests/chat/__init__.py | 0 tests/chat/db/__init__.py | 0 tests/chat/db/test_qa_database.py | 129 +++++++++++++++++++++++++ tests/chat/db/test_ranking_database.py | 116 ++++++++++++++++++++++ tests/chat/llm/__init__.py | 0 tests/chat/llm/test_abs_llm.py | 40 ++++++++ tests/chat/llm/test_constants.py | 30 ++++++ tests/chat/llm/test_huggingface.py | 40 ++++++++ tests/chat/llm/test_openai.py | 54 +++++++++++ tests/component/__init__.py | 0 tests/component/test_base.py | 77 +++++++++++++++ tests/test_application.py | 36 +++++++ tests/test_state.py | 86 +++++++++++++++++ 14 files changed, 608 insertions(+) create mode 100644 tests/__init__.py create mode 100644 tests/chat/__init__.py create mode 100644 tests/chat/db/__init__.py create mode 100644 tests/chat/db/test_qa_database.py create mode 100644 tests/chat/db/test_ranking_database.py create mode 100644 tests/chat/llm/__init__.py create mode 100644 tests/chat/llm/test_abs_llm.py create mode 100644 tests/chat/llm/test_constants.py create mode 100644 tests/chat/llm/test_huggingface.py create mode 100644 tests/chat/llm/test_openai.py create mode 100644 tests/component/__init__.py create mode 100644 tests/component/test_base.py create mode 100644 tests/test_application.py create mode 100644 tests/test_state.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/chat/__init__.py b/tests/chat/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/chat/db/__init__.py b/tests/chat/db/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/chat/db/test_qa_database.py b/tests/chat/db/test_qa_database.py new file mode 100644 index 0000000..d40368b --- /dev/null +++ b/tests/chat/db/test_qa_database.py @@ -0,0 +1,129 @@ +"""Test the QuestionAnswerDatabase class""" +import datetime +import os +import sqlite3 +import unittest + +from pykoi.chat.db.qa_database import QuestionAnswerDatabase + +# Define a temporary database file for testing +TEST_DB_FILE = "test_qd.db" + + +class TestQuestionAnswerDatabase(unittest.TestCase): + """ + Test the QuestionAnswerDatabase class. + """ + + def setUp(self): + # Create a temporary database for testing + self.qadb = QuestionAnswerDatabase(db_file=TEST_DB_FILE, debug=False) + + def tearDown(self): + # Remove the temporary database and close connections after each test + self.qadb.close_connection() + os.remove(TEST_DB_FILE) + + def test_create_table(self): + """ + Test whether the table is created correctly. + """ + # Test whether the table is created correctly + conn = sqlite3.connect(TEST_DB_FILE) + cursor = conn.cursor() + + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='question_answer'" + ) + table_exists = cursor.fetchone() + + self.assertTrue(table_exists) + + # Clean up + cursor.close() + conn.close() + + def test_insert_and_retrieve_question_answer(self): + """ + Test inserting and retrieving a question-answer pair + """ + question = "What is the meaning of life?" + answer = "42" + + # Insert data and get the ID + qa_id = self.qadb.insert_question_answer(question, answer) + + # Retrieve the data + rows = self.qadb.retrieve_all_question_answers() + + # Check if the data was inserted correctly + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0][0], qa_id) + self.assertEqual(rows[0][1], question) + self.assertEqual(rows[0][2], answer) + self.assertEqual(rows[0][3], "n/a") # Default vote status + + def test_update_vote_status(self): + """ + Test updating the vote status of a question-answer pair. + """ + question = "What is the meaning of life?" + answer = "42" + + # Insert data and get the ID + qa_id = self.qadb.insert_question_answer(question, answer) + + # Update the vote status + new_vote_status = "up" + self.qadb.update_vote_status(qa_id, new_vote_status) + + # Retrieve the data + rows = self.qadb.retrieve_all_question_answers() + + # Check if the vote status was updated correctly + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0][0], qa_id) + self.assertEqual(rows[0][3], new_vote_status) + + def test_save_to_csv(self): + """ + Test saving data to a CSV file + """ + question1 = "What is the meaning of life?" + answer1 = "42" + question2 = "What is the best programming language?" + answer2 = "Python" + + # Insert data + timestamp = datetime.datetime.now() + self.qadb.insert_question_answer(question1, answer1) + self.qadb.insert_question_answer(question2, answer2) + + # Save to CSV + self.qadb.save_to_csv("test_csv_file.csv") + + # Check if the CSV file was created and contains the correct data + self.assertTrue(os.path.exists("test_csv_file.csv")) + + with open("test_csv_file.csv", "r") as file: + lines = file.readlines() + + # Verify the CSV file content + timestamp_trim = 10 # Trim 10 characters from the timestamp + self.assertEqual(len(lines), 3) # Header + 2 rows + self.assertEqual(lines[0].strip(), "ID,Question,Answer,Vote Status,Timestamp") + self.assertEqual( + lines[1].strip()[:-timestamp_trim], + f"1,{question1},{answer1},n/a,{timestamp}"[:-timestamp_trim], + ) # Default vote status + self.assertEqual( + lines[2].strip()[:-timestamp_trim], + f"2,{question2},{answer2},n/a,{timestamp}"[:-timestamp_trim], + ) # Default vote status + + # Clean up + os.remove("test_csv_file.csv") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/chat/db/test_ranking_database.py b/tests/chat/db/test_ranking_database.py new file mode 100644 index 0000000..f813e6f --- /dev/null +++ b/tests/chat/db/test_ranking_database.py @@ -0,0 +1,116 @@ +""" +Test the RankingDatabase class +""" +import os +import sqlite3 +import unittest + +from pykoi.chat.db.ranking_database import RankingDatabase + +# Define a temporary database file for testing +TEST_DB_FILE = "test_ranking.db" + + +class TestRankingDatabase(unittest.TestCase): + """ + Test the RankingDatabase class + """ + + def setUp(self): + # Create a temporary database for testing + self.ranking_db = RankingDatabase(db_file=TEST_DB_FILE, debug=False) + + def tearDown(self): + # Remove the temporary database and close connections after each test + self.ranking_db.close_connection() + os.remove(TEST_DB_FILE) + + def test_create_table(self): + """ + Test whether the table is created correctly. + """ + conn = sqlite3.connect(TEST_DB_FILE) + cursor = conn.cursor() + + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='ranking'" + ) + table_exists = cursor.fetchone() + + self.assertTrue(table_exists) + + # Clean up + cursor.close() + conn.close() + + def test_insert_and_retrieve_ranking(self): + """ + Test inserting and retrieving a ranking entry + """ + question = "Which fruit is your favorite?" + up_ranking_answer = "Apple" + low_ranking_answer = "Banana" + + # Insert data and get the ID + ranking_id = self.ranking_db.insert_ranking( + question, up_ranking_answer, low_ranking_answer + ) + + # Retrieve the data + rows = self.ranking_db.retrieve_all_question_answers() + + # Check if the data was inserted correctly + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0][0], ranking_id) + self.assertEqual(rows[0][1], question) + self.assertEqual(rows[0][2], up_ranking_answer) + self.assertEqual(rows[0][3], low_ranking_answer) + + def test_save_to_csv(self): + """ + Test saving data to a CSV file + """ + question1 = "Which fruit is your favorite?" + up_ranking_answer1 = "Apple" + low_ranking_answer1 = "Banana" + question2 = "Which country would you like to visit?" + up_ranking_answer2 = "Japan" + low_ranking_answer2 = "Italy" + + # Insert data + self.ranking_db.insert_ranking( + question1, up_ranking_answer1, low_ranking_answer1 + ) + self.ranking_db.insert_ranking( + question2, up_ranking_answer2, low_ranking_answer2 + ) + + # Save to CSV + self.ranking_db.save_to_csv("test_csv_file.csv") + + # Check if the CSV file was created and contains the correct data + self.assertTrue(os.path.exists("test_csv_file.csv")) + + with open("test_csv_file.csv", "r") as file: + lines = file.readlines() + + # Verify the CSV file content + self.assertEqual(len(lines), 3) # Header + 2 rows + self.assertEqual( + lines[0].strip(), "ID,Question,Up Ranking Answer,Low Ranking Answer" + ) + self.assertEqual( + lines[1].strip(), + f"1,{question1},{up_ranking_answer1},{low_ranking_answer1}", + ) + self.assertEqual( + lines[2].strip(), + f"2,{question2},{up_ranking_answer2},{low_ranking_answer2}", + ) + + # Clean up + os.remove("test_csv_file.csv") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/chat/llm/__init__.py b/tests/chat/llm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/chat/llm/test_abs_llm.py b/tests/chat/llm/test_abs_llm.py new file mode 100644 index 0000000..3a8ce02 --- /dev/null +++ b/tests/chat/llm/test_abs_llm.py @@ -0,0 +1,40 @@ +""" +Test the AbsLlm class. +""" +import unittest + +from pykoi.chat.llm.abs_llm import AbsLlm + + +class DummyLlm(AbsLlm): + """ + Dummy class for testing the abstract base class AbsLlm. + """ + + def predict(self, message: str): + return f"Q: {message}, A: N/A." + + +class TestAbsLlm(unittest.TestCase): + """ + Test the AbsLlm class. + """ + + def test_predict_abstract_method(self): + """ + Test whether the abstract method `predict` raises NotImplementedError + """ + + test_message = "test" + llm = DummyLlm() + self.assertEqual(llm.predict(test_message), f"Q: {test_message}, A: N/A.") + + def test_docstring(self): + """ + Test whether the class has a docstring + """ + self.assertIsNotNone(AbsLlm.__doc__) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/chat/llm/test_constants.py b/tests/chat/llm/test_constants.py new file mode 100644 index 0000000..2e3e71f --- /dev/null +++ b/tests/chat/llm/test_constants.py @@ -0,0 +1,30 @@ +""" +Test the constants of the LLM module. +""" +import unittest + +from pykoi.chat.llm.constants import ModelSource + + +class TestLlmName(unittest.TestCase): + """ + Test the ModelSource enum. + """ + + def test_enum_values(self): + """ + Test whether the enum values are defined correctly + """ + self.assertEqual(ModelSource.OPENAI.value, "openai") + self.assertEqual(ModelSource.HUGGINGFACE.value, "huggingface") + + def test_enum_attributes(self): + """ + Test whether the enum attributes are defined correctly + """ + self.assertEqual(ModelSource.OPENAI.name, "OPENAI") + self.assertEqual(ModelSource.HUGGINGFACE.name, "HUGGINGFACE") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/chat/llm/test_huggingface.py b/tests/chat/llm/test_huggingface.py new file mode 100644 index 0000000..77e2d74 --- /dev/null +++ b/tests/chat/llm/test_huggingface.py @@ -0,0 +1,40 @@ +import unittest +from unittest.mock import patch, Mock +from transformers import AutoModelForCausalLM, AutoTokenizer + +from pykoi.chat.llm.abs_llm import AbsLlm +from pykoi.chat.llm.huggingface import HuggingfaceModel + + +class TestHuggingfaceModel(unittest.TestCase): + @patch.object(AutoModelForCausalLM, "from_pretrained") + @patch.object(AutoTokenizer, "from_pretrained") + def setUp(self, mock_model, mock_tokenizer): + self.model_name = "gpt2" + self.mock_model = mock_model + self.mock_tokenizer = mock_tokenizer + + # Mocking the pretrained model and tokenizer + self.mock_model.return_value = Mock() + self.mock_tokenizer.return_value = Mock() + + self.huggingface_model = HuggingfaceModel( + pretrained_model_name_or_path=self.model_name + ) + + def test_name(self): + expected_name = f"{HuggingfaceModel.model_source}_{self.model_name}_100" + self.assertEqual(self.huggingface_model.name, expected_name) + + @patch.object(HuggingfaceModel, "predict") + def test_predict(self, mock_predict): + mock_predict.return_value = ["Hello, how can I assist you today?"] + message = "Hello, chatbot!" + num_of_response = 1 + response = self.huggingface_model.predict(message, num_of_response) + mock_predict.assert_called_once_with(message, num_of_response) + self.assertEqual(response, ["Hello, how can I assist you today?"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/chat/llm/test_openai.py b/tests/chat/llm/test_openai.py new file mode 100644 index 0000000..5cf01d8 --- /dev/null +++ b/tests/chat/llm/test_openai.py @@ -0,0 +1,54 @@ +""" +Test the OpenAIModel class +""" +import unittest +from unittest.mock import MagicMock, patch + +from pykoi.chat.llm.openai import OpenAIModel + + +class TestOpenAIModel(unittest.TestCase): + """ + Test the OpenAIModel class + """ + + def test_predict(self): + """ + Test the predict method of the OpenAIModel class + """ + # Test predicting the next word based on a given message + message = "What is the meaning of life?" + predicted_word = "42" + + # Mock the OpenAI.Completion.create behavior + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].text = f"Answer: {predicted_word}" + openai_completion_create_mock = MagicMock(return_value=mock_response) + + # Patch the OpenAI.Completion.create method to use the mocked version + with patch( + "pykoi.chat.llm.openai.openai.Completion.create", openai_completion_create_mock + ): + openai_model = OpenAIModel( + api_key="fake_api_key", + engine="davinci", + max_tokens=100, + temperature=0.5, + ) + result = openai_model.predict(message, 1) + + # Check if the OpenAI.Completion.create method was called with the correct arguments + openai_completion_create_mock.assert_called_once_with( + engine="davinci", + prompt=f"Question: {message}\nAnswer:", + max_tokens=100, + n=1, + stop="\n", + temperature=0.5, + ) + self.assertEqual(result[0], f"Answer: {predicted_word}") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/component/__init__.py b/tests/component/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/component/test_base.py b/tests/component/test_base.py new file mode 100644 index 0000000..84df2f9 --- /dev/null +++ b/tests/component/test_base.py @@ -0,0 +1,77 @@ +"""Tests for the base module.""" +import unittest + +from pykoi.component.base import Component, DataSource + + +class TestDataSource(unittest.TestCase): + """ + Unit test class for the DataSource class. + + This class tests the initialization of the DataSource class by creating a DataSource instance + with a specific id and a fetch function. It then asserts that the id and the data fetched by + the fetch function are as expected. + + Attributes: + fetch_func (function): A function that returns the data to be fetched. + ds (DataSource): An instance of the DataSource class. + + Methods: + test_init: Tests the initialization of the DataSource class. + """ + + def test_init(self): + """ + Tests the initialization of the DataSource class. + + This method creates a DataSource instance with a specific id and a fetch function. + It then asserts that the id and the data fetched by the fetch function are as expected. + """ + + def fetch_func(): + return "data" + + data_source = DataSource("test_id", fetch_func) + + self.assertEqual(data_source.id, "test_id") + self.assertEqual(data_source.fetch_func(), "data") + + +class TestComponent(unittest.TestCase): + """ + Unit test class for the Component class. + + This class tests the initialization of the Component class by creating a Component instance + with a specific fetch function, a svelte component, and properties. It then asserts that the id, + the data fetched by the fetch function, the svelte component, and the properties are as expected. + + Attributes: + fetch_func (function): A function that returns the data to be fetched. + comp (Component): An instance of the Component class. + + Methods: + test_init: Tests the initialization of the Component class. + """ + + def test_init(self): + """ + Tests the initialization of the Component class. + + This method creates a Component instance with a specific fetch function, a svelte component, + and properties. It then asserts that the id, the data fetched by the fetch function, the svelte + component, and the properties are as expected. + """ + + def fetch_func(): + return "data" + + comp = Component(fetch_func, "TestComponent", prop1="value1", prop2="value2") + + self.assertIsNotNone(comp.id) + self.assertEqual(comp.data_source.fetch_func(), "data") + self.assertEqual(comp.svelte_component, "TestComponent") + self.assertEqual(comp.props, {"prop1": "value1", "prop2": "value2"}) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_application.py b/tests/test_application.py new file mode 100644 index 0000000..4f5f3b9 --- /dev/null +++ b/tests/test_application.py @@ -0,0 +1,36 @@ +""" +Test the Application class. +""" + +import os +import unittest +from unittest.mock import MagicMock, patch + +from pykoi.application import Application + + +class TestApplication(unittest.TestCase): + """ + Unit test class for the Application class. + """ + + def setUp(self): + self.app = Application(share=False, debug=False) + + def test_add_component(self): + """ + Tests adding a component to the application. + """ + component = MagicMock() + component.data_source = MagicMock() + component.id = "test_component" + + self.app.add_component(component) + + # Check if the component and its data source are added to the application + self.assertIn("test_component", self.app.data_sources) + self.assertIn(component, [c["component"] for c in self.app.components]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_state.py b/tests/test_state.py new file mode 100644 index 0000000..3854817 --- /dev/null +++ b/tests/test_state.py @@ -0,0 +1,86 @@ +""" +Tests for the State class. +""" +import unittest + +from pykoi.state import State, Store + + +class TestState(unittest.TestCase): + """ + Unit test class for the State class. + """ + + def test_get_attribute(self): + """ + Tests getting an attribute from the state. + """ + state = State() + state.state = {"test_attr": 42} + + self.assertEqual(state.test_attr, 42) + + def test_get_non_existing_attribute(self): + """ + Tests getting a non-existing attribute from the state. + """ + state = State() + + with self.assertRaises(AttributeError): + attr = state.non_existing_attr + + def test_set_attribute(self): + """ + Tests setting an attribute in the state. + """ + state = State() + state.test_attr = 42 + + self.assertEqual(state.state["test_attr"], 42) + + def test_call_method(self): + """ + Tests calling a method from the state. + """ + state = State() + state.state = {"test_method": lambda x: x * 2} + + result = state.test_method(3) + self.assertEqual(result, 6) + + +class TestStore(unittest.TestCase): + """ + Unit test class for the Store class. + """ + + def test_increment(self): + """ + Tests incrementing the count. + """ + store = Store() + store.increment() + + self.assertEqual(store.count, 5) + + def test_decrement(self): + """ + Tests decrementing the count. + """ + store = Store() + store.decrement() + + self.assertEqual(store.count, 3) + + def test_hello(self): + """ + Tests the hello method. + """ + store = Store() + hello_msg = store.hello() + + self.assertEqual(hello_msg, "hello jared") + + +if __name__ == "__main__": + unittest.main()