From 21ed6d4bc23344e5e5cbd7091d4391534ec3dd8d Mon Sep 17 00:00:00 2001 From: Mohammad Amin <dadgaramin96@gmail.com> Date: Tue, 23 Apr 2024 10:28:42 +0330 Subject: [PATCH 1/4] feat: Added github query engine! --- bot/retrievers/custom_retriever.py | 59 +++++++++++++++++++ subquery.py | 15 ++++- tests/unit/test_base_engine.py | 22 +++++++ tests/unit/test_github_query_engine.py | 15 +++++ utils/query_engine/__init__.py | 1 + utils/query_engine/base_engine.py | 25 ++++++++ utils/query_engine/github.py | 30 ++++++++++ .../level_based_platform_query_engine.py | 24 +------- 8 files changed, 167 insertions(+), 24 deletions(-) create mode 100644 bot/retrievers/custom_retriever.py create mode 100644 tests/unit/test_base_engine.py create mode 100644 tests/unit/test_github_query_engine.py create mode 100644 utils/query_engine/base_engine.py create mode 100644 utils/query_engine/github.py diff --git a/bot/retrievers/custom_retriever.py b/bot/retrievers/custom_retriever.py new file mode 100644 index 0000000..deac4ac --- /dev/null +++ b/bot/retrievers/custom_retriever.py @@ -0,0 +1,59 @@ +""" +We're going to override the `_build_node_list_from_query_result` since it is raising errors having the llama-index legacy & newer version together +""" + +from llama_index.core.data_structs.data_structs import IndexDict +from llama_index.core.indices.utils import log_vector_store_query_result +from llama_index.core.indices.vector_store.retrievers.retriever import ( + VectorIndexRetriever, +) +from llama_index.core.schema import Node, NodeWithScore, ObjectType +from llama_index.core.vector_stores.types import VectorStoreQueryResult + + +class CustomVectorStoreRetriever(VectorIndexRetriever): + + def _build_node_list_from_query_result( + self, query_result: VectorStoreQueryResult + ) -> list[NodeWithScore]: + if query_result.nodes is None: + # NOTE: vector store does not keep text and returns node indices. + # Need to recover all nodes from docstore + if query_result.ids is None: + raise ValueError( + "Vector store query result should return at " + "least one of nodes or ids." + ) + assert isinstance(self._index.index_struct, IndexDict) + node_ids = [ + self._index.index_struct.nodes_dict[idx] for idx in query_result.ids + ] + nodes = self._docstore.get_nodes(node_ids) + query_result.nodes = nodes + else: + # NOTE: vector store keeps text, returns nodes. + # Only need to recover image or index nodes from docstore + for i in range(len(query_result.nodes)): + source_node = query_result.nodes[i].source_node + if (not self._vector_store.stores_text) or ( + source_node is not None and source_node.node_type != ObjectType.TEXT + ): + node_id = query_result.nodes[i].node_id + if self._docstore.document_exists(node_id): + query_result.nodes[i] = self._docstore.get_node( + node_id + ) # type: ignore[index] + + log_vector_store_query_result(query_result) + node_with_scores: list[NodeWithScore] = [] + for ind, node in enumerate(query_result.nodes): + score: float | None = None + if query_result.similarities is not None: + score = query_result.similarities[ind] + # This is the part we updated + node_new = Node.from_dict(node.to_dict()) + node_with_score = NodeWithScore(node=node_new, score=score) + + node_with_scores.append(node_with_score) + + return node_with_scores diff --git a/subquery.py b/subquery.py index 5710999..a8c31ff 100644 --- a/subquery.py +++ b/subquery.py @@ -9,6 +9,7 @@ from tc_hivemind_backend.embeddings.cohere import CohereEmbedding from utils.query_engine import ( DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL, + GitHubQueryEngine, prepare_discord_engine_auto_filter, ) @@ -64,11 +65,11 @@ def query_multiple_source( tools: list[ToolMetadata] = [] discord_query_engine: BaseQueryEngine + github_query_engine: BaseQueryEngine # discourse_query_engine: BaseQueryEngine # gdrive_query_engine: BaseQueryEngine # notion_query_engine: BaseQueryEngine # telegram_query_engine: BaseQueryEngine - # github_query_engine: BaseQueryEngine # query engine perparation # tools_metadata and query_engine_tools @@ -99,7 +100,17 @@ def query_multiple_source( if telegram: raise NotImplementedError if github: - raise NotImplementedError + github_query_engine = GitHubQueryEngine(community_id=community_id).prepare() + tool_metadata = ToolMetadata( + name="GitHub", + description="Hosts code repositories and project materials from the GitHub platform.", + ) + query_engine_tools.append( + QueryEngineTool( + query_engine=github_query_engine, + metadata=tool_metadata, + ) + ) embed_model = CohereEmbedding() llm = OpenAI("gpt-3.5-turbo") diff --git a/tests/unit/test_base_engine.py b/tests/unit/test_base_engine.py new file mode 100644 index 0000000..f08119f --- /dev/null +++ b/tests/unit/test_base_engine.py @@ -0,0 +1,22 @@ +from unittest import TestCase + +from utils.query_engine.base_engine import BaseEngine + + +class TestBaseEngine(TestCase): + + def test_setup_vector_store_index(self): + """ + Tests that _setup_vector_store_index creates a PGVectorAccess object + and calls its load_index method. + """ + platform_table_name = "test_table" + dbname = "test_db" + + base_engine = BaseEngine._setup_vector_store_index( + platform_table_name=platform_table_name, + dbname=dbname, + testing=True, + ) + self.assertIn(dbname, base_engine.vector_store.connection_string) + self.assertEqual(base_engine.vector_store.table_name, platform_table_name) diff --git a/tests/unit/test_github_query_engine.py b/tests/unit/test_github_query_engine.py new file mode 100644 index 0000000..1361398 --- /dev/null +++ b/tests/unit/test_github_query_engine.py @@ -0,0 +1,15 @@ +from unittest import TestCase + +from bot.retrievers.custom_retriever import CustomVectorStoreRetriever +from utils.query_engine import GitHubQueryEngine + + +class TestGitHubQueryEngine(TestCase): + def setUp(self) -> None: + community_id = "sample_community" + self.github_query_engine = GitHubQueryEngine(community_id) + + def test_prepare_engine(self): + github_query_engine = self.github_query_engine.prepare(testing=True) + print(github_query_engine.__dict__) + self.assertIsInstance(github_query_engine.retriever, CustomVectorStoreRetriever) diff --git a/utils/query_engine/__init__.py b/utils/query_engine/__init__.py index a988e54..ea544ac 100644 --- a/utils/query_engine/__init__.py +++ b/utils/query_engine/__init__.py @@ -1,3 +1,4 @@ # flake8: noqa from .prepare_discord_query_engine import prepare_discord_engine_auto_filter from .subquery_gen_prompt import DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL +from .github import GitHubQueryEngine diff --git a/utils/query_engine/base_engine.py b/utils/query_engine/base_engine.py new file mode 100644 index 0000000..ed59008 --- /dev/null +++ b/utils/query_engine/base_engine.py @@ -0,0 +1,25 @@ +from llama_index.core import VectorStoreIndex +from tc_hivemind_backend.embeddings import CohereEmbedding +from tc_hivemind_backend.pg_vector_access import PGVectorAccess + + +class BaseEngine: + + @classmethod + def _setup_vector_store_index( + cls, + platform_table_name: str, + dbname: str, + testing: bool = False, + ) -> VectorStoreIndex: + """ + prepare the vector_store for querying data + """ + pg_vector = PGVectorAccess( + table_name=platform_table_name, + dbname=dbname, + testing=testing, + embed_model=CohereEmbedding(), + ) + index = pg_vector.load_index() + return index diff --git a/utils/query_engine/github.py b/utils/query_engine/github.py new file mode 100644 index 0000000..e9822e8 --- /dev/null +++ b/utils/query_engine/github.py @@ -0,0 +1,30 @@ +from utils.query_engine.base_engine import BaseEngine +from llama_index.core.query_engine import RetrieverQueryEngine +from bot.retrievers.custom_retriever import CustomVectorStoreRetriever +from llama_index.core import get_response_synthesizer +from bot.retrievers.utils.load_hyperparams import load_hyperparams + + +class GitHubQueryEngine(BaseEngine): + platform_name = "github" + + def __init__(self, community_id: str) -> None: + dbname = f"community_{community_id}" + self.dbname = dbname + + def prepare(self, testing=False): + index = self._setup_vector_store_index( + platform_table_name=self.platform_name, + dbname=self.dbname, + testing=testing, + ) + _, similarity_top_k, _ = load_hyperparams() + + retriever = CustomVectorStoreRetriever( + index=index, similarity_top_k=similarity_top_k + ) + query_engine = RetrieverQueryEngine( + retriever=retriever, + response_synthesizer=get_response_synthesizer(), + ) + return query_engine diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py index 82dc7a7..6647876 100644 --- a/utils/query_engine/level_based_platform_query_engine.py +++ b/utils/query_engine/level_based_platform_query_engine.py @@ -13,8 +13,7 @@ from llama_index.core.retrievers import BaseRetriever from llama_index.core.schema import NodeWithScore from llama_index.llms.openai import OpenAI -from tc_hivemind_backend.embeddings.cohere import CohereEmbedding -from tc_hivemind_backend.pg_vector_access import PGVectorAccess +from utils.query_engine.base_engine import BaseEngine from utils.query_engine.level_based_platforms_util import LevelBasedPlatformUtils qa_prompt = PromptTemplate( @@ -29,7 +28,7 @@ ) -class LevelBasedPlatformQueryEngine(CustomQueryEngine): +class LevelBasedPlatformQueryEngine(CustomQueryEngine, BaseEngine): retriever: BaseRetriever response_synthesizer: BaseSynthesizer llm: OpenAI @@ -335,22 +334,3 @@ def _prepare_context_str( logging.debug(f"context_str of prompt\n" f"{context_str}") return context_str - - @classmethod - def _setup_vector_store_index( - cls, - platform_table_name: str, - dbname: str, - testing: bool = False, - ) -> VectorStoreIndex: - """ - prepare the vector_store for querying data - """ - pg_vector = PGVectorAccess( - table_name=platform_table_name, - dbname=dbname, - testing=testing, - embed_model=CohereEmbedding(), - ) - index = pg_vector.load_index() - return index From 261c8af78513b7a0873f95acedb437a3cefc5e09 Mon Sep 17 00:00:00 2001 From: Mohammad Amin <dadgaramin96@gmail.com> Date: Tue, 23 Apr 2024 10:39:16 +0330 Subject: [PATCH 2/4] fix: linter issues! --- bot/retrievers/custom_retriever.py | 4 ++-- tests/unit/test_base_engine.py | 1 - utils/query_engine/__init__.py | 2 +- utils/query_engine/base_engine.py | 1 - utils/query_engine/github.py | 6 +++--- 5 files changed, 6 insertions(+), 8 deletions(-) diff --git a/bot/retrievers/custom_retriever.py b/bot/retrievers/custom_retriever.py index deac4ac..bb05ec0 100644 --- a/bot/retrievers/custom_retriever.py +++ b/bot/retrievers/custom_retriever.py @@ -1,5 +1,6 @@ """ -We're going to override the `_build_node_list_from_query_result` since it is raising errors having the llama-index legacy & newer version together +We're going to override the `_build_node_list_from_query_result` +since it is raising errors having the llama-index legacy & newer version together """ from llama_index.core.data_structs.data_structs import IndexDict @@ -12,7 +13,6 @@ class CustomVectorStoreRetriever(VectorIndexRetriever): - def _build_node_list_from_query_result( self, query_result: VectorStoreQueryResult ) -> list[NodeWithScore]: diff --git a/tests/unit/test_base_engine.py b/tests/unit/test_base_engine.py index f08119f..09820a6 100644 --- a/tests/unit/test_base_engine.py +++ b/tests/unit/test_base_engine.py @@ -4,7 +4,6 @@ class TestBaseEngine(TestCase): - def test_setup_vector_store_index(self): """ Tests that _setup_vector_store_index creates a PGVectorAccess object diff --git a/utils/query_engine/__init__.py b/utils/query_engine/__init__.py index ea544ac..ce734c4 100644 --- a/utils/query_engine/__init__.py +++ b/utils/query_engine/__init__.py @@ -1,4 +1,4 @@ # flake8: noqa +from .github import GitHubQueryEngine from .prepare_discord_query_engine import prepare_discord_engine_auto_filter from .subquery_gen_prompt import DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL -from .github import GitHubQueryEngine diff --git a/utils/query_engine/base_engine.py b/utils/query_engine/base_engine.py index ed59008..72bda45 100644 --- a/utils/query_engine/base_engine.py +++ b/utils/query_engine/base_engine.py @@ -4,7 +4,6 @@ class BaseEngine: - @classmethod def _setup_vector_store_index( cls, diff --git a/utils/query_engine/github.py b/utils/query_engine/github.py index e9822e8..b5adf61 100644 --- a/utils/query_engine/github.py +++ b/utils/query_engine/github.py @@ -1,8 +1,8 @@ -from utils.query_engine.base_engine import BaseEngine -from llama_index.core.query_engine import RetrieverQueryEngine from bot.retrievers.custom_retriever import CustomVectorStoreRetriever -from llama_index.core import get_response_synthesizer from bot.retrievers.utils.load_hyperparams import load_hyperparams +from llama_index.core import get_response_synthesizer +from llama_index.core.query_engine import RetrieverQueryEngine +from utils.query_engine.base_engine import BaseEngine class GitHubQueryEngine(BaseEngine): From 6077677bfdc2c5e3f9fd2544fc01ed787d6a22b9 Mon Sep 17 00:00:00 2001 From: Mohammad Amin <dadgaramin96@gmail.com> Date: Tue, 23 Apr 2024 10:49:23 +0330 Subject: [PATCH 3/4] feat: Enabled querying github data source! --- celery_app/tasks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/celery_app/tasks.py b/celery_app/tasks.py index 9dfe16d..eb8f99b 100644 --- a/celery_app/tasks.py +++ b/celery_app/tasks.py @@ -75,6 +75,7 @@ def ask_question_auto_search( query=question, community_id=community_id, discord=True, + github=True, ) # source_nodes_dict: list[dict[str, Any]] = [] From 25242dbaf84288de336450b8750a8ab1acd95f76 Mon Sep 17 00:00:00 2001 From: Mohammad Amin <dadgaramin96@gmail.com> Date: Tue, 23 Apr 2024 12:20:20 +0330 Subject: [PATCH 4/4] feat: Updated github queryengine description! --- subquery.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/subquery.py b/subquery.py index a8c31ff..c2db9c7 100644 --- a/subquery.py +++ b/subquery.py @@ -103,7 +103,10 @@ def query_multiple_source( github_query_engine = GitHubQueryEngine(community_id=community_id).prepare() tool_metadata = ToolMetadata( name="GitHub", - description="Hosts code repositories and project materials from the GitHub platform.", + description=( + "Hosts commits and conversations from Github issues and" + " pull requests from the selected repositories" + ), ) query_engine_tools.append( QueryEngineTool(