From 5da603fdbd710787ca2f0a39b0798575150042cd Mon Sep 17 00:00:00 2001
From: ZimpleX <zhqhku@gmail.com>
Date: Mon, 18 Sep 2023 22:40:02 -0700
Subject: [PATCH] add unit tests

---
 tests/__init__.py                      |   0
 tests/chat/__init__.py                 |   0
 tests/chat/db/__init__.py              |   0
 tests/chat/db/test_qa_database.py      | 129 +++++++++++++++++++++++++
 tests/chat/db/test_ranking_database.py | 116 ++++++++++++++++++++++
 tests/chat/llm/__init__.py             |   0
 tests/chat/llm/test_abs_llm.py         |  40 ++++++++
 tests/chat/llm/test_constants.py       |  30 ++++++
 tests/chat/llm/test_huggingface.py     |  40 ++++++++
 tests/chat/llm/test_openai.py          |  54 +++++++++++
 tests/component/__init__.py            |   0
 tests/component/test_base.py           |  77 +++++++++++++++
 tests/test_application.py              |  36 +++++++
 tests/test_state.py                    |  86 +++++++++++++++++
 14 files changed, 608 insertions(+)
 create mode 100644 tests/__init__.py
 create mode 100644 tests/chat/__init__.py
 create mode 100644 tests/chat/db/__init__.py
 create mode 100644 tests/chat/db/test_qa_database.py
 create mode 100644 tests/chat/db/test_ranking_database.py
 create mode 100644 tests/chat/llm/__init__.py
 create mode 100644 tests/chat/llm/test_abs_llm.py
 create mode 100644 tests/chat/llm/test_constants.py
 create mode 100644 tests/chat/llm/test_huggingface.py
 create mode 100644 tests/chat/llm/test_openai.py
 create mode 100644 tests/component/__init__.py
 create mode 100644 tests/component/test_base.py
 create mode 100644 tests/test_application.py
 create mode 100644 tests/test_state.py

diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/chat/__init__.py b/tests/chat/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/chat/db/__init__.py b/tests/chat/db/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/chat/db/test_qa_database.py b/tests/chat/db/test_qa_database.py
new file mode 100644
index 0000000..d40368b
--- /dev/null
+++ b/tests/chat/db/test_qa_database.py
@@ -0,0 +1,129 @@
+"""Test the QuestionAnswerDatabase class"""
+import datetime
+import os
+import sqlite3
+import unittest
+
+from pykoi.chat.db.qa_database import QuestionAnswerDatabase
+
+# Define a temporary database file for testing
+TEST_DB_FILE = "test_qd.db"
+
+
+class TestQuestionAnswerDatabase(unittest.TestCase):
+    """
+    Test the QuestionAnswerDatabase class.
+    """
+
+    def setUp(self):
+        # Create a temporary database for testing
+        self.qadb = QuestionAnswerDatabase(db_file=TEST_DB_FILE, debug=False)
+
+    def tearDown(self):
+        # Remove the temporary database and close connections after each test
+        self.qadb.close_connection()
+        os.remove(TEST_DB_FILE)
+
+    def test_create_table(self):
+        """
+        Test whether the table is created correctly.
+        """
+        # Test whether the table is created correctly
+        conn = sqlite3.connect(TEST_DB_FILE)
+        cursor = conn.cursor()
+
+        cursor.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name='question_answer'"
+        )
+        table_exists = cursor.fetchone()
+
+        self.assertTrue(table_exists)
+
+        # Clean up
+        cursor.close()
+        conn.close()
+
+    def test_insert_and_retrieve_question_answer(self):
+        """
+        Test inserting and retrieving a question-answer pair
+        """
+        question = "What is the meaning of life?"
+        answer = "42"
+
+        # Insert data and get the ID
+        qa_id = self.qadb.insert_question_answer(question, answer)
+
+        # Retrieve the data
+        rows = self.qadb.retrieve_all_question_answers()
+
+        # Check if the data was inserted correctly
+        self.assertEqual(len(rows), 1)
+        self.assertEqual(rows[0][0], qa_id)
+        self.assertEqual(rows[0][1], question)
+        self.assertEqual(rows[0][2], answer)
+        self.assertEqual(rows[0][3], "n/a")  # Default vote status
+
+    def test_update_vote_status(self):
+        """
+        Test updating the vote status of a question-answer pair.
+        """
+        question = "What is the meaning of life?"
+        answer = "42"
+
+        # Insert data and get the ID
+        qa_id = self.qadb.insert_question_answer(question, answer)
+
+        # Update the vote status
+        new_vote_status = "up"
+        self.qadb.update_vote_status(qa_id, new_vote_status)
+
+        # Retrieve the data
+        rows = self.qadb.retrieve_all_question_answers()
+
+        # Check if the vote status was updated correctly
+        self.assertEqual(len(rows), 1)
+        self.assertEqual(rows[0][0], qa_id)
+        self.assertEqual(rows[0][3], new_vote_status)
+
+    def test_save_to_csv(self):
+        """
+        Test saving data to a CSV file
+        """
+        question1 = "What is the meaning of life?"
+        answer1 = "42"
+        question2 = "What is the best programming language?"
+        answer2 = "Python"
+
+        # Insert data
+        timestamp = datetime.datetime.now()
+        self.qadb.insert_question_answer(question1, answer1)
+        self.qadb.insert_question_answer(question2, answer2)
+
+        # Save to CSV
+        self.qadb.save_to_csv("test_csv_file.csv")
+
+        # Check if the CSV file was created and contains the correct data
+        self.assertTrue(os.path.exists("test_csv_file.csv"))
+
+        with open("test_csv_file.csv", "r") as file:
+            lines = file.readlines()
+
+        # Verify the CSV file content
+        timestamp_trim = 10  # Trim 10 characters from the timestamp
+        self.assertEqual(len(lines), 3)  # Header + 2 rows
+        self.assertEqual(lines[0].strip(), "ID,Question,Answer,Vote Status,Timestamp")
+        self.assertEqual(
+            lines[1].strip()[:-timestamp_trim],
+            f"1,{question1},{answer1},n/a,{timestamp}"[:-timestamp_trim],
+        )  # Default vote status
+        self.assertEqual(
+            lines[2].strip()[:-timestamp_trim],
+            f"2,{question2},{answer2},n/a,{timestamp}"[:-timestamp_trim],
+        )  # Default vote status
+
+        # Clean up
+        os.remove("test_csv_file.csv")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/chat/db/test_ranking_database.py b/tests/chat/db/test_ranking_database.py
new file mode 100644
index 0000000..f813e6f
--- /dev/null
+++ b/tests/chat/db/test_ranking_database.py
@@ -0,0 +1,116 @@
+"""
+Test the RankingDatabase class
+"""
+import os
+import sqlite3
+import unittest
+
+from pykoi.chat.db.ranking_database import RankingDatabase
+
+# Define a temporary database file for testing
+TEST_DB_FILE = "test_ranking.db"
+
+
+class TestRankingDatabase(unittest.TestCase):
+    """
+    Test the RankingDatabase class
+    """
+
+    def setUp(self):
+        # Create a temporary database for testing
+        self.ranking_db = RankingDatabase(db_file=TEST_DB_FILE, debug=False)
+
+    def tearDown(self):
+        # Remove the temporary database and close connections after each test
+        self.ranking_db.close_connection()
+        os.remove(TEST_DB_FILE)
+
+    def test_create_table(self):
+        """
+        Test whether the table is created correctly.
+        """
+        conn = sqlite3.connect(TEST_DB_FILE)
+        cursor = conn.cursor()
+
+        cursor.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name='ranking'"
+        )
+        table_exists = cursor.fetchone()
+
+        self.assertTrue(table_exists)
+
+        # Clean up
+        cursor.close()
+        conn.close()
+
+    def test_insert_and_retrieve_ranking(self):
+        """
+        Test inserting and retrieving a ranking entry
+        """
+        question = "Which fruit is your favorite?"
+        up_ranking_answer = "Apple"
+        low_ranking_answer = "Banana"
+
+        # Insert data and get the ID
+        ranking_id = self.ranking_db.insert_ranking(
+            question, up_ranking_answer, low_ranking_answer
+        )
+
+        # Retrieve the data
+        rows = self.ranking_db.retrieve_all_question_answers()
+
+        # Check if the data was inserted correctly
+        self.assertEqual(len(rows), 1)
+        self.assertEqual(rows[0][0], ranking_id)
+        self.assertEqual(rows[0][1], question)
+        self.assertEqual(rows[0][2], up_ranking_answer)
+        self.assertEqual(rows[0][3], low_ranking_answer)
+
+    def test_save_to_csv(self):
+        """
+        Test saving data to a CSV file
+        """
+        question1 = "Which fruit is your favorite?"
+        up_ranking_answer1 = "Apple"
+        low_ranking_answer1 = "Banana"
+        question2 = "Which country would you like to visit?"
+        up_ranking_answer2 = "Japan"
+        low_ranking_answer2 = "Italy"
+
+        # Insert data
+        self.ranking_db.insert_ranking(
+            question1, up_ranking_answer1, low_ranking_answer1
+        )
+        self.ranking_db.insert_ranking(
+            question2, up_ranking_answer2, low_ranking_answer2
+        )
+
+        # Save to CSV
+        self.ranking_db.save_to_csv("test_csv_file.csv")
+
+        # Check if the CSV file was created and contains the correct data
+        self.assertTrue(os.path.exists("test_csv_file.csv"))
+
+        with open("test_csv_file.csv", "r") as file:
+            lines = file.readlines()
+
+        # Verify the CSV file content
+        self.assertEqual(len(lines), 3)  # Header + 2 rows
+        self.assertEqual(
+            lines[0].strip(), "ID,Question,Up Ranking Answer,Low Ranking Answer"
+        )
+        self.assertEqual(
+            lines[1].strip(),
+            f"1,{question1},{up_ranking_answer1},{low_ranking_answer1}",
+        )
+        self.assertEqual(
+            lines[2].strip(),
+            f"2,{question2},{up_ranking_answer2},{low_ranking_answer2}",
+        )
+
+        # Clean up
+        os.remove("test_csv_file.csv")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/chat/llm/__init__.py b/tests/chat/llm/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/chat/llm/test_abs_llm.py b/tests/chat/llm/test_abs_llm.py
new file mode 100644
index 0000000..3a8ce02
--- /dev/null
+++ b/tests/chat/llm/test_abs_llm.py
@@ -0,0 +1,40 @@
+"""
+Test the AbsLlm class.
+"""
+import unittest
+
+from pykoi.chat.llm.abs_llm import AbsLlm
+
+
+class DummyLlm(AbsLlm):
+    """
+    Dummy class for testing the abstract base class AbsLlm.
+    """
+
+    def predict(self, message: str):
+        return f"Q: {message}, A: N/A."
+
+
+class TestAbsLlm(unittest.TestCase):
+    """
+    Test the AbsLlm class.
+    """
+
+    def test_predict_abstract_method(self):
+        """
+        Test whether the abstract method `predict` raises NotImplementedError
+        """
+
+        test_message = "test"
+        llm = DummyLlm()
+        self.assertEqual(llm.predict(test_message), f"Q: {test_message}, A: N/A.")
+
+    def test_docstring(self):
+        """
+        Test whether the class has a docstring
+        """
+        self.assertIsNotNone(AbsLlm.__doc__)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/chat/llm/test_constants.py b/tests/chat/llm/test_constants.py
new file mode 100644
index 0000000..2e3e71f
--- /dev/null
+++ b/tests/chat/llm/test_constants.py
@@ -0,0 +1,30 @@
+"""
+Test the constants of the LLM module.
+"""
+import unittest
+
+from pykoi.chat.llm.constants import ModelSource
+
+
+class TestLlmName(unittest.TestCase):
+    """
+    Test the ModelSource enum.
+    """
+
+    def test_enum_values(self):
+        """
+        Test whether the enum values are defined correctly
+        """
+        self.assertEqual(ModelSource.OPENAI.value, "openai")
+        self.assertEqual(ModelSource.HUGGINGFACE.value, "huggingface")
+
+    def test_enum_attributes(self):
+        """
+        Test whether the enum attributes are defined correctly
+        """
+        self.assertEqual(ModelSource.OPENAI.name, "OPENAI")
+        self.assertEqual(ModelSource.HUGGINGFACE.name, "HUGGINGFACE")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/chat/llm/test_huggingface.py b/tests/chat/llm/test_huggingface.py
new file mode 100644
index 0000000..77e2d74
--- /dev/null
+++ b/tests/chat/llm/test_huggingface.py
@@ -0,0 +1,40 @@
+import unittest
+from unittest.mock import patch, Mock
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from pykoi.chat.llm.abs_llm import AbsLlm
+from pykoi.chat.llm.huggingface import HuggingfaceModel
+
+
+class TestHuggingfaceModel(unittest.TestCase):
+    @patch.object(AutoModelForCausalLM, "from_pretrained")
+    @patch.object(AutoTokenizer, "from_pretrained")
+    def setUp(self, mock_model, mock_tokenizer):
+        self.model_name = "gpt2"
+        self.mock_model = mock_model
+        self.mock_tokenizer = mock_tokenizer
+
+        # Mocking the pretrained model and tokenizer
+        self.mock_model.return_value = Mock()
+        self.mock_tokenizer.return_value = Mock()
+
+        self.huggingface_model = HuggingfaceModel(
+            pretrained_model_name_or_path=self.model_name
+        )
+
+    def test_name(self):
+        expected_name = f"{HuggingfaceModel.model_source}_{self.model_name}_100"
+        self.assertEqual(self.huggingface_model.name, expected_name)
+
+    @patch.object(HuggingfaceModel, "predict")
+    def test_predict(self, mock_predict):
+        mock_predict.return_value = ["Hello, how can I assist you today?"]
+        message = "Hello, chatbot!"
+        num_of_response = 1
+        response = self.huggingface_model.predict(message, num_of_response)
+        mock_predict.assert_called_once_with(message, num_of_response)
+        self.assertEqual(response, ["Hello, how can I assist you today?"])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/chat/llm/test_openai.py b/tests/chat/llm/test_openai.py
new file mode 100644
index 0000000..5cf01d8
--- /dev/null
+++ b/tests/chat/llm/test_openai.py
@@ -0,0 +1,54 @@
+"""
+Test the OpenAIModel class
+"""
+import unittest
+from unittest.mock import MagicMock, patch
+
+from pykoi.chat.llm.openai import OpenAIModel
+
+
+class TestOpenAIModel(unittest.TestCase):
+    """
+    Test the OpenAIModel class
+    """
+
+    def test_predict(self):
+        """
+        Test the predict method of the OpenAIModel class
+        """
+        # Test predicting the next word based on a given message
+        message = "What is the meaning of life?"
+        predicted_word = "42"
+
+        # Mock the OpenAI.Completion.create behavior
+        mock_response = MagicMock()
+        mock_response.choices = [MagicMock()]
+        mock_response.choices[0].text = f"Answer: {predicted_word}"
+        openai_completion_create_mock = MagicMock(return_value=mock_response)
+
+        # Patch the OpenAI.Completion.create method to use the mocked version
+        with patch(
+            "pykoi.chat.llm.openai.openai.Completion.create", openai_completion_create_mock
+        ):
+            openai_model = OpenAIModel(
+                api_key="fake_api_key",
+                engine="davinci",
+                max_tokens=100,
+                temperature=0.5,
+            )
+            result = openai_model.predict(message, 1)
+
+        # Check if the OpenAI.Completion.create method was called with the correct arguments
+        openai_completion_create_mock.assert_called_once_with(
+            engine="davinci",
+            prompt=f"Question: {message}\nAnswer:",
+            max_tokens=100,
+            n=1,
+            stop="\n",
+            temperature=0.5,
+        )
+        self.assertEqual(result[0], f"Answer: {predicted_word}")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/component/__init__.py b/tests/component/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/component/test_base.py b/tests/component/test_base.py
new file mode 100644
index 0000000..84df2f9
--- /dev/null
+++ b/tests/component/test_base.py
@@ -0,0 +1,77 @@
+"""Tests for the base module."""
+import unittest
+
+from pykoi.component.base import Component, DataSource
+
+
+class TestDataSource(unittest.TestCase):
+    """
+    Unit test class for the DataSource class.
+
+    This class tests the initialization of the DataSource class by creating a DataSource instance
+    with a specific id and a fetch function. It then asserts that the id and the data fetched by
+    the fetch function are as expected.
+
+    Attributes:
+        fetch_func (function): A function that returns the data to be fetched.
+        ds (DataSource): An instance of the DataSource class.
+
+    Methods:
+        test_init: Tests the initialization of the DataSource class.
+    """
+
+    def test_init(self):
+        """
+        Tests the initialization of the DataSource class.
+
+        This method creates a DataSource instance with a specific id and a fetch function.
+        It then asserts that the id and the data fetched by the fetch function are as expected.
+        """
+
+        def fetch_func():
+            return "data"
+
+        data_source = DataSource("test_id", fetch_func)
+
+        self.assertEqual(data_source.id, "test_id")
+        self.assertEqual(data_source.fetch_func(), "data")
+
+
+class TestComponent(unittest.TestCase):
+    """
+    Unit test class for the Component class.
+
+    This class tests the initialization of the Component class by creating a Component instance
+    with a specific fetch function, a svelte component, and properties. It then asserts that the id,
+    the data fetched by the fetch function, the svelte component, and the properties are as expected.
+
+    Attributes:
+        fetch_func (function): A function that returns the data to be fetched.
+        comp (Component): An instance of the Component class.
+
+    Methods:
+        test_init: Tests the initialization of the Component class.
+    """
+
+    def test_init(self):
+        """
+        Tests the initialization of the Component class.
+
+        This method creates a Component instance with a specific fetch function, a svelte component,
+        and properties. It then asserts that the id, the data fetched by the fetch function, the svelte
+        component, and the properties are as expected.
+        """
+
+        def fetch_func():
+            return "data"
+
+        comp = Component(fetch_func, "TestComponent", prop1="value1", prop2="value2")
+
+        self.assertIsNotNone(comp.id)
+        self.assertEqual(comp.data_source.fetch_func(), "data")
+        self.assertEqual(comp.svelte_component, "TestComponent")
+        self.assertEqual(comp.props, {"prop1": "value1", "prop2": "value2"})
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_application.py b/tests/test_application.py
new file mode 100644
index 0000000..4f5f3b9
--- /dev/null
+++ b/tests/test_application.py
@@ -0,0 +1,36 @@
+"""
+Test the Application class.
+"""
+
+import os
+import unittest
+from unittest.mock import MagicMock, patch
+
+from pykoi.application import Application
+
+
+class TestApplication(unittest.TestCase):
+    """
+    Unit test class for the Application class.
+    """
+
+    def setUp(self):
+        self.app = Application(share=False, debug=False)
+
+    def test_add_component(self):
+        """
+        Tests adding a component to the application.
+        """
+        component = MagicMock()
+        component.data_source = MagicMock()
+        component.id = "test_component"
+
+        self.app.add_component(component)
+
+        # Check if the component and its data source are added to the application
+        self.assertIn("test_component", self.app.data_sources)
+        self.assertIn(component, [c["component"] for c in self.app.components])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_state.py b/tests/test_state.py
new file mode 100644
index 0000000..3854817
--- /dev/null
+++ b/tests/test_state.py
@@ -0,0 +1,86 @@
+"""
+Tests for the State class.
+"""
+import unittest
+
+from pykoi.state import State, Store
+
+
+class TestState(unittest.TestCase):
+    """
+    Unit test class for the State class.
+    """
+
+    def test_get_attribute(self):
+        """
+        Tests getting an attribute from the state.
+        """
+        state = State()
+        state.state = {"test_attr": 42}
+
+        self.assertEqual(state.test_attr, 42)
+
+    def test_get_non_existing_attribute(self):
+        """
+        Tests getting a non-existing attribute from the state.
+        """
+        state = State()
+
+        with self.assertRaises(AttributeError):
+            attr = state.non_existing_attr
+
+    def test_set_attribute(self):
+        """
+        Tests setting an attribute in the state.
+        """
+        state = State()
+        state.test_attr = 42
+
+        self.assertEqual(state.state["test_attr"], 42)
+
+    def test_call_method(self):
+        """
+        Tests calling a method from the state.
+        """
+        state = State()
+        state.state = {"test_method": lambda x: x * 2}
+
+        result = state.test_method(3)
+        self.assertEqual(result, 6)
+
+
+class TestStore(unittest.TestCase):
+    """
+    Unit test class for the Store class.
+    """
+
+    def test_increment(self):
+        """
+        Tests incrementing the count.
+        """
+        store = Store()
+        store.increment()
+
+        self.assertEqual(store.count, 5)
+
+    def test_decrement(self):
+        """
+        Tests decrementing the count.
+        """
+        store = Store()
+        store.decrement()
+
+        self.assertEqual(store.count, 3)
+
+    def test_hello(self):
+        """
+        Tests the hello method.
+        """
+        store = Store()
+        hello_msg = store.hello()
+
+        self.assertEqual(hello_msg, "hello jared")
+
+
+if __name__ == "__main__":
+    unittest.main()