Skip to content

Commit

Permalink
Update: Adding more test case to improve coverage!
Browse files Browse the repository at this point in the history
  • Loading branch information
amindadgar committed Feb 1, 2024
1 parent ce46c03 commit e08197a
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 3 deletions.
1 change: 0 additions & 1 deletion bot/retrievers/retrieve_similar_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def __init__(
self._vector_store = vector_store
self._embed_model = embed_model
self._similarity_top_k = similarity_top_k
super().__init__()

def query_db(
self, query: str, filters: list[dict[str, str]] | None = None
Expand Down
4 changes: 2 additions & 2 deletions bot/retrievers/summary_retriever_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,13 @@ def get_similar_nodes(
return nodes

def _setup_index(
self, table_name: str, dbname: str, embedding_model: BaseEmbedding
self, table_name: str, dbname: str, embedding_model: BaseEmbedding, testing: bool = False,
) -> VectorStoreIndex:
"""
setup the llama_index VectorStoreIndex
"""
pg_vector_access = PGVectorAccess(
table_name=table_name, dbname=dbname, embed_model=embedding_model
table_name=table_name, dbname=dbname, embed_model=embedding_model, testing=testing
)
index = pg_vector_access.load_index()
return index
36 changes: 36 additions & 0 deletions tests/integration/test_retrieve_similar_nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from unittest import TestCase
from unittest.mock import MagicMock

from collections import namedtuple

from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes
from llama_index.schema import NodeWithScore, TextNode


class TestRetrieveSimilarNodes(TestCase):
def setUp(self):
self.table_name = "sample_table"
self.dbname = "community_some_id"

self.vector_store = MagicMock()
self.embed_model = MagicMock()
self.retriever = RetrieveSimilarNodes(
vector_store=self.vector_store,
similarity_top_k=5,
embed_model=self.embed_model
)

def test_init(self):
self.assertEqual(self.retriever._similarity_top_k, 5)
self.assertEqual(self.vector_store, self.retriever._vector_store)

def test_get_nodes_with_score(self):
# Test the _get_nodes_with_score private method
query_result = MagicMock()
query_result.nodes = [TextNode(), TextNode(), TextNode()]
query_result.similarities = [0.8, 0.9, 0.7]

result = self.retriever._get_nodes_with_score(query_result)

self.assertEqual(len(result), 3)
self.assertAlmostEqual(result[0].score, 0.8, delta=0.001)
10 changes: 10 additions & 0 deletions tests/unit/test_summary_retriever_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,13 @@ def test_initialize_class(self):
nodes = base_summary_search.get_similar_nodes(query="what is samplesample?")
self.assertIsInstance(nodes, list)
self.assertIsInstance(nodes[0], NodeWithScore)

def test_setup_index(self):
table_name = "your_table_name"
dbname = "your_db_name"
embedding_model = MagicMock()
search_instance = BaseSummarySearch(table_name, dbname, embedding_model)

index = search_instance._setup_index(table_name, dbname, embedding_model, testing=True)
self.assertIsNotNone(index)
self.assertIsInstance(index, VectorStoreIndex)

0 comments on commit e08197a

Please sign in to comment.