Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added github query engine! #50

Merged
merged 4 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions bot/retrievers/custom_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
We're going to override the `_build_node_list_from_query_result`
since it is raising errors having the llama-index legacy & newer version together
"""

from llama_index.core.data_structs.data_structs import IndexDict
from llama_index.core.indices.utils import log_vector_store_query_result
from llama_index.core.indices.vector_store.retrievers.retriever import (
VectorIndexRetriever,
)
from llama_index.core.schema import Node, NodeWithScore, ObjectType
from llama_index.core.vector_stores.types import VectorStoreQueryResult


class CustomVectorStoreRetriever(VectorIndexRetriever):
def _build_node_list_from_query_result(
self, query_result: VectorStoreQueryResult
) -> list[NodeWithScore]:
if query_result.nodes is None:
# NOTE: vector store does not keep text and returns node indices.
# Need to recover all nodes from docstore
if query_result.ids is None:
raise ValueError(
"Vector store query result should return at "
"least one of nodes or ids."
)
assert isinstance(self._index.index_struct, IndexDict)
node_ids = [
self._index.index_struct.nodes_dict[idx] for idx in query_result.ids
]
nodes = self._docstore.get_nodes(node_ids)
query_result.nodes = nodes
else:
# NOTE: vector store keeps text, returns nodes.
# Only need to recover image or index nodes from docstore
for i in range(len(query_result.nodes)):
source_node = query_result.nodes[i].source_node
if (not self._vector_store.stores_text) or (
source_node is not None and source_node.node_type != ObjectType.TEXT
):
node_id = query_result.nodes[i].node_id
if self._docstore.document_exists(node_id):
query_result.nodes[i] = self._docstore.get_node(
node_id
) # type: ignore[index]

log_vector_store_query_result(query_result)
node_with_scores: list[NodeWithScore] = []
for ind, node in enumerate(query_result.nodes):
score: float | None = None
if query_result.similarities is not None:
score = query_result.similarities[ind]
# This is the part we updated
node_new = Node.from_dict(node.to_dict())
node_with_score = NodeWithScore(node=node_new, score=score)

node_with_scores.append(node_with_score)

return node_with_scores
1 change: 1 addition & 0 deletions celery_app/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def ask_question_auto_search(
query=question,
community_id=community_id,
discord=True,
github=True,
)

# source_nodes_dict: list[dict[str, Any]] = []
Expand Down
18 changes: 16 additions & 2 deletions subquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
from utils.query_engine import (
DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL,
GitHubQueryEngine,
prepare_discord_engine_auto_filter,
)

Expand Down Expand Up @@ -64,11 +65,11 @@ def query_multiple_source(
tools: list[ToolMetadata] = []

discord_query_engine: BaseQueryEngine
github_query_engine: BaseQueryEngine
# discourse_query_engine: BaseQueryEngine
# gdrive_query_engine: BaseQueryEngine
# notion_query_engine: BaseQueryEngine
# telegram_query_engine: BaseQueryEngine
# github_query_engine: BaseQueryEngine

# query engine perparation
# tools_metadata and query_engine_tools
Expand Down Expand Up @@ -99,7 +100,20 @@ def query_multiple_source(
if telegram:
raise NotImplementedError
if github:
raise NotImplementedError
github_query_engine = GitHubQueryEngine(community_id=community_id).prepare()
tool_metadata = ToolMetadata(
name="GitHub",
description=(
"Hosts commits and conversations from Github issues and"
" pull requests from the selected repositories"
),
)
query_engine_tools.append(
QueryEngineTool(
query_engine=github_query_engine,
metadata=tool_metadata,
)
)

embed_model = CohereEmbedding()
llm = OpenAI("gpt-3.5-turbo")
Expand Down
21 changes: 21 additions & 0 deletions tests/unit/test_base_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from unittest import TestCase

from utils.query_engine.base_engine import BaseEngine


class TestBaseEngine(TestCase):
def test_setup_vector_store_index(self):
"""
Tests that _setup_vector_store_index creates a PGVectorAccess object
and calls its load_index method.
"""
platform_table_name = "test_table"
dbname = "test_db"

base_engine = BaseEngine._setup_vector_store_index(
platform_table_name=platform_table_name,
dbname=dbname,
testing=True,
)
self.assertIn(dbname, base_engine.vector_store.connection_string)
self.assertEqual(base_engine.vector_store.table_name, platform_table_name)
15 changes: 15 additions & 0 deletions tests/unit/test_github_query_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from unittest import TestCase

from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
from utils.query_engine import GitHubQueryEngine


class TestGitHubQueryEngine(TestCase):
def setUp(self) -> None:
community_id = "sample_community"
self.github_query_engine = GitHubQueryEngine(community_id)

def test_prepare_engine(self):
github_query_engine = self.github_query_engine.prepare(testing=True)
print(github_query_engine.__dict__)
self.assertIsInstance(github_query_engine.retriever, CustomVectorStoreRetriever)
1 change: 1 addition & 0 deletions utils/query_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# flake8: noqa
from .github import GitHubQueryEngine
from .prepare_discord_query_engine import prepare_discord_engine_auto_filter
from .subquery_gen_prompt import DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL
24 changes: 24 additions & 0 deletions utils/query_engine/base_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from llama_index.core import VectorStoreIndex
from tc_hivemind_backend.embeddings import CohereEmbedding
from tc_hivemind_backend.pg_vector_access import PGVectorAccess


class BaseEngine:
@classmethod
def _setup_vector_store_index(
cls,
platform_table_name: str,
dbname: str,
testing: bool = False,
) -> VectorStoreIndex:
"""
prepare the vector_store for querying data
"""
pg_vector = PGVectorAccess(
table_name=platform_table_name,
dbname=dbname,
testing=testing,
embed_model=CohereEmbedding(),
)
index = pg_vector.load_index()
return index
30 changes: 30 additions & 0 deletions utils/query_engine/github.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
from bot.retrievers.utils.load_hyperparams import load_hyperparams
from llama_index.core import get_response_synthesizer
from llama_index.core.query_engine import RetrieverQueryEngine
from utils.query_engine.base_engine import BaseEngine


class GitHubQueryEngine(BaseEngine):
platform_name = "github"

def __init__(self, community_id: str) -> None:
dbname = f"community_{community_id}"
self.dbname = dbname

def prepare(self, testing=False):
index = self._setup_vector_store_index(
platform_table_name=self.platform_name,
dbname=self.dbname,
testing=testing,
)
_, similarity_top_k, _ = load_hyperparams()

retriever = CustomVectorStoreRetriever(
index=index, similarity_top_k=similarity_top_k
)
query_engine = RetrieverQueryEngine(
retriever=retriever,
response_synthesizer=get_response_synthesizer(),
)
return query_engine
24 changes: 2 additions & 22 deletions utils/query_engine/level_based_platform_query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore
from llama_index.llms.openai import OpenAI
from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
from tc_hivemind_backend.pg_vector_access import PGVectorAccess
from utils.query_engine.base_engine import BaseEngine
from utils.query_engine.level_based_platforms_util import LevelBasedPlatformUtils

qa_prompt = PromptTemplate(
Expand All @@ -29,7 +28,7 @@
)


class LevelBasedPlatformQueryEngine(CustomQueryEngine):
class LevelBasedPlatformQueryEngine(CustomQueryEngine, BaseEngine):
retriever: BaseRetriever
response_synthesizer: BaseSynthesizer
llm: OpenAI
Expand Down Expand Up @@ -335,22 +334,3 @@ def _prepare_context_str(
logging.debug(f"context_str of prompt\n" f"{context_str}")

return context_str

@classmethod
def _setup_vector_store_index(
cls,
platform_table_name: str,
dbname: str,
testing: bool = False,
) -> VectorStoreIndex:
"""
prepare the vector_store for querying data
"""
pg_vector = PGVectorAccess(
table_name=platform_table_name,
dbname=dbname,
testing=testing,
embed_model=CohereEmbedding(),
)
index = pg_vector.load_index()
return index
Loading