"""
We're going to override the `_build_node_list_from_query_result` since it is raising errors having the llama-index legacy & newer version together
"""

from llama_index.core.data_structs.data_structs import IndexDict
from llama_index.core.indices.utils import log_vector_store_query_result
from llama_index.core.indices.vector_store.retrievers.retriever import (
    VectorIndexRetriever,
)
from llama_index.core.schema import Node, NodeWithScore, ObjectType
from llama_index.core.vector_stores.types import VectorStoreQueryResult


class CustomVectorStoreRetriever(VectorIndexRetriever):
    # Identical to VectorIndexRetriever except that every retrieved node is
    # round-tripped through `Node.from_dict(node.to_dict())` before scoring,
    # so node objects created by a different llama-index version can still be
    # wrapped in `NodeWithScore` (see the marked section at the bottom).

    def _build_node_list_from_query_result(
        self, query_result: VectorStoreQueryResult
    ) -> list[NodeWithScore]:
        """
        Convert a raw ``VectorStoreQueryResult`` into a list of scored nodes.

        Parameters
        ----------
        query_result : VectorStoreQueryResult
            the raw result returned by the vector store query; may carry
            nodes directly, or only ids that must be resolved via the docstore

        Returns
        -------
        list[NodeWithScore]
            retrieved nodes paired with their similarity scores
            (``score`` is ``None`` when the store returns no similarities)

        Raises
        ------
        ValueError
            if the query result carries neither nodes nor ids
        """
        if query_result.nodes is None:
            # NOTE: vector store does not keep text and returns node indices.
            # Need to recover all nodes from docstore
            if query_result.ids is None:
                raise ValueError(
                    "Vector store query result should return at "
                    "least one of nodes or ids."
                )
            assert isinstance(self._index.index_struct, IndexDict)
            node_ids = [
                self._index.index_struct.nodes_dict[idx] for idx in query_result.ids
            ]
            nodes = self._docstore.get_nodes(node_ids)
            query_result.nodes = nodes
        else:
            # NOTE: vector store keeps text, returns nodes.
            # Only need to recover image or index nodes from docstore
            for i in range(len(query_result.nodes)):
                source_node = query_result.nodes[i].source_node
                if (not self._vector_store.stores_text) or (
                    source_node is not None and source_node.node_type != ObjectType.TEXT
                ):
                    node_id = query_result.nodes[i].node_id
                    if self._docstore.document_exists(node_id):
                        query_result.nodes[i] = self._docstore.get_node(
                            node_id
                        )  # type: ignore[index]

        log_vector_store_query_result(query_result)
        node_with_scores: list[NodeWithScore] = []
        for ind, node in enumerate(query_result.nodes):
            score: float | None = None
            if query_result.similarities is not None:
                score = query_result.similarities[ind]
            # This is the part we updated: rebuild the node through the
            # currently-installed `Node` class so that version-mismatched
            # node objects don't raise when wrapped in `NodeWithScore`.
            node_new = Node.from_dict(node.to_dict())
            node_with_score = NodeWithScore(node=node_new, score=score)

            node_with_scores.append(node_with_score)

        return node_with_scores
from unittest import TestCase

from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
from utils.query_engine import GitHubQueryEngine
from utils.query_engine.base_engine import BaseEngine


class TestBaseEngine(TestCase):
    """Unit tests for the shared BaseEngine vector-store helper."""

    def test_setup_vector_store_index(self):
        """
        Tests that _setup_vector_store_index creates a PGVectorAccess object
        and calls its load_index method.
        """
        platform_table_name = "test_table"
        dbname = "test_db"

        base_engine = BaseEngine._setup_vector_store_index(
            platform_table_name=platform_table_name,
            dbname=dbname,
            testing=True,
        )
        # the returned index must be wired to a vector store pointing at the
        # requested database and table
        self.assertIn(dbname, base_engine.vector_store.connection_string)
        self.assertEqual(base_engine.vector_store.table_name, platform_table_name)


class TestGitHubQueryEngine(TestCase):
    """Unit tests for the GitHub query engine preparation."""

    def setUp(self) -> None:
        community_id = "sample_community"
        self.github_query_engine = GitHubQueryEngine(community_id)

    def test_prepare_engine(self):
        """
        prepare() should return a query engine whose retriever is the
        custom version-bridging retriever.
        """
        # NOTE: removed a leftover debug `print(github_query_engine.__dict__)`
        # that was committed with the original test.
        github_query_engine = self.github_query_engine.prepare(testing=True)
        self.assertIsInstance(github_query_engine.retriever, CustomVectorStoreRetriever)
from llama_index.core import VectorStoreIndex, get_response_synthesizer
from llama_index.core.query_engine import RetrieverQueryEngine
from tc_hivemind_backend.embeddings import CohereEmbedding
from tc_hivemind_backend.pg_vector_access import PGVectorAccess

from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
from bot.retrievers.utils.load_hyperparams import load_hyperparams


class BaseEngine:
    """Shared helpers for platform-specific Hivemind query engines."""

    @classmethod
    def _setup_vector_store_index(
        cls,
        platform_table_name: str,
        dbname: str,
        testing: bool = False,
    ) -> VectorStoreIndex:
        """
        prepare the vector_store for querying data

        Parameters
        ----------
        platform_table_name : str
            the pgvector table holding the platform's documents
        dbname : str
            the database name to connect to
        testing : bool
            when True, PGVectorAccess is run in testing mode
            (default ``False``)

        Returns
        -------
        index : VectorStoreIndex
            the index loaded from the postgres vector store
        """
        pg_vector = PGVectorAccess(
            table_name=platform_table_name,
            dbname=dbname,
            testing=testing,
            embed_model=CohereEmbedding(),
        )
        index = pg_vector.load_index()
        return index


class GitHubQueryEngine(BaseEngine):
    """Query engine over a community's GitHub data stored in pgvector."""

    # table name the github documents are stored under
    platform_name = "github"

    def __init__(self, community_id: str) -> None:
        # each community's data lives in its own database
        self.dbname = f"community_{community_id}"

    def prepare(self, testing: bool = False) -> RetrieverQueryEngine:
        """
        Build the retriever-backed query engine for the github table.

        Parameters
        ----------
        testing : bool
            passed through to the vector-store setup (default ``False``)

        Returns
        -------
        RetrieverQueryEngine
            an engine using CustomVectorStoreRetriever for top-k retrieval
        """
        index = self._setup_vector_store_index(
            platform_table_name=self.platform_name,
            dbname=self.dbname,
            testing=testing,
        )
        # only the similarity top-k hyperparameter is used here
        _, similarity_top_k, _ = load_hyperparams()

        retriever = CustomVectorStoreRetriever(
            index=index, similarity_top_k=similarity_top_k
        )
        query_engine = RetrieverQueryEngine(
            retriever=retriever,
            response_synthesizer=get_response_synthesizer(),
        )
        return query_engine
PGVectorAccess +from utils.query_engine.base_engine import BaseEngine from utils.query_engine.level_based_platforms_util import LevelBasedPlatformUtils qa_prompt = PromptTemplate( @@ -29,7 +28,7 @@ ) -class LevelBasedPlatformQueryEngine(CustomQueryEngine): +class LevelBasedPlatformQueryEngine(CustomQueryEngine, BaseEngine): retriever: BaseRetriever response_synthesizer: BaseSynthesizer llm: OpenAI @@ -335,22 +334,3 @@ def _prepare_context_str( logging.debug(f"context_str of prompt\n" f"{context_str}") return context_str - - @classmethod - def _setup_vector_store_index( - cls, - platform_table_name: str, - dbname: str, - testing: bool = False, - ) -> VectorStoreIndex: - """ - prepare the vector_store for querying data - """ - pg_vector = PGVectorAccess( - table_name=platform_table_name, - dbname=dbname, - testing=testing, - embed_model=CohereEmbedding(), - ) - index = pg_vector.load_index() - return index