Skip to content

Commit

Permalink
Merge pull request #50 from TogetherCrew/feat/github-rag
Browse files Browse the repository at this point in the history
feat: Added github query engine!
  • Loading branch information
amindadgar authored Apr 23, 2024
2 parents bc4fb96 + 25242db commit 21deca4
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 24 deletions.
59 changes: 59 additions & 0 deletions bot/retrievers/custom_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
We're going to override the `_build_node_list_from_query_result`
since it is raising errors having the llama-index legacy & newer version together
"""

from llama_index.core.data_structs.data_structs import IndexDict
from llama_index.core.indices.utils import log_vector_store_query_result
from llama_index.core.indices.vector_store.retrievers.retriever import (
VectorIndexRetriever,
)
from llama_index.core.schema import Node, NodeWithScore, ObjectType
from llama_index.core.vector_stores.types import VectorStoreQueryResult


class CustomVectorStoreRetriever(VectorIndexRetriever):
    """Vector index retriever that rebuilds result nodes via ``Node.from_dict``.

    Overrides ``_build_node_list_from_query_result`` because having the
    llama-index legacy and newer packages installed together raises errors
    when the parent implementation wraps the raw node objects directly.
    """

    def _build_node_list_from_query_result(
        self, query_result: VectorStoreQueryResult
    ) -> list[NodeWithScore]:
        """Convert a vector-store query result into a list of scored nodes.

        Parameters
        ----------
        query_result : VectorStoreQueryResult
            raw result returned by the underlying vector store

        Returns
        -------
        node_with_scores : list[NodeWithScore]
            nodes re-created through ``Node.from_dict`` paired with their
            similarity scores (``None`` when no similarities were returned)

        Raises
        ------
        ValueError
            if the result contains neither nodes nor ids
        """
        if query_result.nodes is None:
            # NOTE: vector store does not keep text and returns node indices.
            # Need to recover all nodes from docstore
            if query_result.ids is None:
                raise ValueError(
                    "Vector store query result should return at "
                    "least one of nodes or ids."
                )
            assert isinstance(self._index.index_struct, IndexDict)
            node_ids = [
                self._index.index_struct.nodes_dict[idx] for idx in query_result.ids
            ]
            query_result.nodes = self._docstore.get_nodes(node_ids)
        else:
            # NOTE: vector store keeps text, returns nodes.
            # Only need to recover image or index nodes from docstore
            for i, result_node in enumerate(query_result.nodes):
                source_node = result_node.source_node
                if (not self._vector_store.stores_text) or (
                    source_node is not None and source_node.node_type != ObjectType.TEXT
                ):
                    node_id = result_node.node_id
                    if self._docstore.document_exists(node_id):
                        query_result.nodes[i] = self._docstore.get_node(
                            node_id
                        )  # type: ignore[index]

        log_vector_store_query_result(query_result)

        node_with_scores: list[NodeWithScore] = []
        for ind, node in enumerate(query_result.nodes):
            score: float | None = None
            if query_result.similarities is not None:
                score = query_result.similarities[ind]
            # Round-trip through a plain `Node` so legacy and newer
            # llama-index node classes interoperate (the customized part).
            node_new = Node.from_dict(node.to_dict())
            node_with_scores.append(NodeWithScore(node=node_new, score=score))

        return node_with_scores
1 change: 1 addition & 0 deletions celery_app/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def ask_question_auto_search(
query=question,
community_id=community_id,
discord=True,
github=True,
)

# source_nodes_dict: list[dict[str, Any]] = []
Expand Down
18 changes: 16 additions & 2 deletions subquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
from utils.query_engine import (
DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL,
GitHubQueryEngine,
prepare_discord_engine_auto_filter,
)

Expand Down Expand Up @@ -64,11 +65,11 @@ def query_multiple_source(
tools: list[ToolMetadata] = []

discord_query_engine: BaseQueryEngine
github_query_engine: BaseQueryEngine
# discourse_query_engine: BaseQueryEngine
# gdrive_query_engine: BaseQueryEngine
# notion_query_engine: BaseQueryEngine
# telegram_query_engine: BaseQueryEngine
# github_query_engine: BaseQueryEngine

# query engine preparation
# tools_metadata and query_engine_tools
Expand Down Expand Up @@ -99,7 +100,20 @@ def query_multiple_source(
if telegram:
raise NotImplementedError
if github:
raise NotImplementedError
github_query_engine = GitHubQueryEngine(community_id=community_id).prepare()
tool_metadata = ToolMetadata(
name="GitHub",
description=(
"Hosts commits and conversations from Github issues and"
" pull requests from the selected repositories"
),
)
query_engine_tools.append(
QueryEngineTool(
query_engine=github_query_engine,
metadata=tool_metadata,
)
)

embed_model = CohereEmbedding()
llm = OpenAI("gpt-3.5-turbo")
Expand Down
21 changes: 21 additions & 0 deletions tests/unit/test_base_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from unittest import TestCase

from utils.query_engine.base_engine import BaseEngine


class TestBaseEngine(TestCase):
    def test_setup_vector_store_index(self):
        """
        Verify that ``_setup_vector_store_index`` builds an index whose
        vector store points at the requested database and table.
        """
        table = "test_table"
        database = "test_db"

        index = BaseEngine._setup_vector_store_index(
            platform_table_name=table,
            dbname=database,
            testing=True,
        )
        self.assertIn(database, index.vector_store.connection_string)
        self.assertEqual(index.vector_store.table_name, table)
15 changes: 15 additions & 0 deletions tests/unit/test_github_query_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from unittest import TestCase

from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
from utils.query_engine import GitHubQueryEngine


class TestGitHubQueryEngine(TestCase):
    def setUp(self) -> None:
        # Build the engine under test for a sample community id.
        community_id = "sample_community"
        self.github_query_engine = GitHubQueryEngine(community_id)

    def test_prepare_engine(self):
        """
        ``prepare(testing=True)`` should return an engine wired up with
        the custom vector store retriever.
        """
        # NOTE: removed a leftover debug print of the engine's __dict__.
        github_query_engine = self.github_query_engine.prepare(testing=True)
        self.assertIsInstance(github_query_engine.retriever, CustomVectorStoreRetriever)
1 change: 1 addition & 0 deletions utils/query_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# flake8: noqa
from .github import GitHubQueryEngine
from .prepare_discord_query_engine import prepare_discord_engine_auto_filter
from .subquery_gen_prompt import DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL
24 changes: 24 additions & 0 deletions utils/query_engine/base_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from llama_index.core import VectorStoreIndex
from tc_hivemind_backend.embeddings import CohereEmbedding
from tc_hivemind_backend.pg_vector_access import PGVectorAccess


class BaseEngine:
    @classmethod
    def _setup_vector_store_index(
        cls,
        platform_table_name: str,
        dbname: str,
        testing: bool = False,
    ) -> VectorStoreIndex:
        """
        Build a ``VectorStoreIndex`` over the given platform table.

        Parameters
        ----------
        platform_table_name : str
            table holding the platform's embedded documents
        dbname : str
            name of the database to connect to
        testing : bool
            forwarded to ``PGVectorAccess`` — presumably selects a test
            connection; confirm against tc_hivemind_backend

        Returns
        -------
        index : VectorStoreIndex
            the index loaded from the postgres vector store
        """
        accessor = PGVectorAccess(
            table_name=platform_table_name,
            dbname=dbname,
            testing=testing,
            embed_model=CohereEmbedding(),
        )
        return accessor.load_index()
30 changes: 30 additions & 0 deletions utils/query_engine/github.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
from bot.retrievers.utils.load_hyperparams import load_hyperparams
from llama_index.core import get_response_synthesizer
from llama_index.core.query_engine import RetrieverQueryEngine
from utils.query_engine.base_engine import BaseEngine


class GitHubQueryEngine(BaseEngine):
    """Query engine over a community's GitHub vector store."""

    platform_name = "github"

    def __init__(self, community_id: str) -> None:
        # each community's data lives in its own database
        self.dbname = f"community_{community_id}"

    def prepare(self, testing=False):
        """
        Build a ``RetrieverQueryEngine`` backed by the custom retriever.

        Parameters
        ----------
        testing : bool
            forwarded to the vector store setup

        Returns
        -------
        query_engine : RetrieverQueryEngine
            engine ready to answer queries against the github table
        """
        vector_index = self._setup_vector_store_index(
            platform_table_name=self.platform_name,
            dbname=self.dbname,
            testing=testing,
        )
        _, similarity_top_k, _ = load_hyperparams()

        return RetrieverQueryEngine(
            retriever=CustomVectorStoreRetriever(
                index=vector_index, similarity_top_k=similarity_top_k
            ),
            response_synthesizer=get_response_synthesizer(),
        )
24 changes: 2 additions & 22 deletions utils/query_engine/level_based_platform_query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore
from llama_index.llms.openai import OpenAI
from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
from tc_hivemind_backend.pg_vector_access import PGVectorAccess
from utils.query_engine.base_engine import BaseEngine
from utils.query_engine.level_based_platforms_util import LevelBasedPlatformUtils

qa_prompt = PromptTemplate(
Expand All @@ -29,7 +28,7 @@
)


class LevelBasedPlatformQueryEngine(CustomQueryEngine):
class LevelBasedPlatformQueryEngine(CustomQueryEngine, BaseEngine):
retriever: BaseRetriever
response_synthesizer: BaseSynthesizer
llm: OpenAI
Expand Down Expand Up @@ -335,22 +334,3 @@ def _prepare_context_str(
logging.debug(f"context_str of prompt\n" f"{context_str}")

return context_str

@classmethod
def _setup_vector_store_index(
cls,
platform_table_name: str,
dbname: str,
testing: bool = False,
) -> VectorStoreIndex:
"""
prepare the vector_store for querying data
"""
pg_vector = PGVectorAccess(
table_name=platform_table_name,
dbname=dbname,
testing=testing,
embed_model=CohereEmbedding(),
)
index = pg_vector.load_index()
return index

0 comments on commit 21deca4

Please sign in to comment.