Skip to content

Commit

Permalink
Merge pull request #99 from TogetherCrew/feat/90-qdrant-summary-and-r…
Browse files Browse the repository at this point in the history
…aw-engine

Feat/90 qdrant summary and raw engine
  • Loading branch information
amindadgar authored Nov 14, 2024
2 parents c5056be + d924922 commit b444891
Show file tree
Hide file tree
Showing 8 changed files with 403 additions and 81 deletions.
8 changes: 8 additions & 0 deletions schema/type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from enum import Enum


# Primitive data types used to describe metadata fields.
# Built with the Enum functional API; each member's value mirrors its
# name so the type round-trips cleanly through string serialization.
DataType = Enum(
    "DataType",
    [(label, label) for label in ("INTEGER", "STRING", "BOOLEAN", "FLOAT")],
)
21 changes: 11 additions & 10 deletions subquery.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from guidance.models import OpenAIChat
from llama_index.core import QueryBundle, Settings
from llama_index.core.base.base_query_engine import BaseQueryEngine
from llama_index.core.query_engine import SubQuestionQueryEngine
from llama_index.core.schema import NodeWithScore
from llama_index.core.tools import QueryEngineTool, ToolMetadata
Expand All @@ -14,6 +13,7 @@
GitHubQueryEngine,
MediaWikiQueryEngine,
NotionQueryEngine,
TelegramDualQueryEngine,
TelegramQueryEngine,
prepare_discord_engine_auto_filter,
)
Expand Down Expand Up @@ -71,14 +71,6 @@ def query_multiple_source(
tools: list[ToolMetadata] = []
qdrant_utils = QDrantUtils(community_id)

discord_query_engine: BaseQueryEngine
github_query_engine: BaseQueryEngine
# discourse_query_engine: BaseQueryEngine
google_query_engine: BaseQueryEngine
notion_query_engine: BaseQueryEngine
mediawiki_query_engine: BaseQueryEngine
# telegram_query_engine: BaseQueryEngine

# wrapper for more clarity
check_collection = qdrant_utils.check_collection_exist

Expand Down Expand Up @@ -134,7 +126,16 @@ def query_multiple_source(
)
)
if telegram and check_collection("telegram"):
telegram_query_engine = TelegramQueryEngine(community_id=community_id).prepare()
# checking if the summaries were available
if check_collection("telegram_summary"):
telegram_query_engine = TelegramDualQueryEngine(
community_id=community_id
).prepare()
else:
telegram_query_engine = TelegramQueryEngine(
community_id=community_id
).prepare()

tool_metadata = ToolMetadata(
name="Telegram",
description=(
Expand Down
25 changes: 0 additions & 25 deletions tests/unit/test_base_qdrant_engine.py

This file was deleted.

3 changes: 2 additions & 1 deletion utils/query_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# flake8: noqa
from .dual_qdrant_retrieval_engine import DualQdrantRetrievalEngine
from .gdrive import GDriveQueryEngine
from .github import GitHubQueryEngine
from .media_wiki import MediaWikiQueryEngine
from .notion import NotionQueryEngine
from .prepare_discord_query_engine import prepare_discord_engine_auto_filter
from .subquery_gen_prompt import DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL
from .telegram import TelegramQueryEngine
from .telegram import TelegramDualQueryEngine, TelegramQueryEngine
56 changes: 11 additions & 45 deletions utils/query_engine/base_qdrant_engine.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from bot.retrievers.utils.load_hyperparams import load_hyperparams
from llama_index.core import VectorStoreIndex, get_response_synthesizer
from llama_index.core.indices.vector_store.retrievers.retriever import (
VectorIndexRetriever,
)
from llama_index.core.query_engine import RetrieverQueryEngine
from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess
from llama_index.core import Settings, get_response_synthesizer
from llama_index.core.query_engine import BaseQueryEngine

from .dual_qdrant_retrieval_engine import DualQdrantRetrievalEngine


class BaseQdrantEngine:
Expand All @@ -24,44 +21,13 @@ def __init__(self, platform_name: str, community_id: str) -> None:
"""
self.platform_name = platform_name
self.community_id = community_id
self.collection_name = f"{self.community_id}_{platform_name}"

def prepare(self, testing=False):
vector_store_index = self._setup_vector_store_index(
testing=testing,
def prepare(self, testing=False) -> BaseQueryEngine:
engine = DualQdrantRetrievalEngine.setup_engine(
llm=Settings.llm,
synthesizer=get_response_synthesizer(),
platform_name=self.platform_name,
community_id=self.community_id,
)
_, similarity_top_k, _ = load_hyperparams()

retriever = VectorIndexRetriever(
index=vector_store_index,
similarity_top_k=similarity_top_k,
)
query_engine = RetrieverQueryEngine(
retriever=retriever,
response_synthesizer=get_response_synthesizer(),
)
return query_engine

def _setup_vector_store_index(
self,
testing: bool = False,
**kwargs,
) -> VectorStoreIndex:
"""
prepare the vector_store for querying data
Parameters
------------
testing : bool
for testing purposes
**kwargs :
collection_name : str
to override the default collection_name
"""
collection_name = kwargs.get("collection_name", self.collection_name)
qdrant_vector = QDrantVectorAccess(
collection_name=collection_name,
testing=testing,
)
index = qdrant_vector.load_index()
return index
return engine
210 changes: 210 additions & 0 deletions utils/query_engine/dual_qdrant_retrieval_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
from bot.retrievers.utils.load_hyperparams import load_hyperparams
from llama_index.core import PromptTemplate, VectorStoreIndex
from llama_index.core.indices.vector_store.retrievers.retriever import (
VectorIndexRetriever,
)
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.response_synthesizers import BaseSynthesizer
from llama_index.core.retrievers import BaseRetriever
from llama_index.llms.openai import OpenAI
from schema.type import DataType
from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess
from utils.query_engine.qdrant_query_engine_utils import QdrantEngineUtils

# Default QA prompt shared by all engines in this module: instructs the LLM
# to answer strictly from the retrieved context ("not prior knowledge").
# Filled with `context_str` and `query_str` before each `llm.complete` call.
qa_prompt = PromptTemplate(
    "Context information is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Given the context information and not prior knowledge, "
    "answer the query.\n"
    "Query: {query_str}\n"
    "Answer: "
)


class DualQdrantRetrievalEngine(CustomQueryEngine):
    """
    RAG query engine over Qdrant collections, with an optional summary-first step.

    Built via one of two classmethod factories:
    - `setup_engine`: raw-data retrieval only (`summary_retriever` is set to None).
    - `setup_engine_with_summaries`: retrieve summary documents first, derive
      date filters from their metadata, then retrieve date-filtered raw data.

    NOTE(review): the factories store configuration on the *class*
    (`cls.summary_retriever`, `cls._date_margin`, `cls._vector_store_index`, ...),
    so configuring a new engine mutates state shared with previously created
    instances — confirm a single-active-engine assumption holds for all callers.
    """

    # Fields supplied by the factory classmethods when calling `cls(...)`.
    retriever: BaseRetriever
    response_synthesizer: BaseSynthesizer
    llm: OpenAI
    qa_prompt: PromptTemplate

    def custom_query(self, query_str: str):
        """Answer `query_str`, using summaries when a summary retriever is configured."""
        # `summary_retriever` is None when the engine was built via `setup_engine`
        if self.summary_retriever is None:
            response = self._process_basic_query(query_str)
        else:
            response = self._process_summary_query(query_str)
        return str(response)

    @classmethod
    def setup_engine(
        cls,
        llm: OpenAI,
        synthesizer: BaseSynthesizer,
        platform_name: str,
        community_id: str,
    ):
        """
        setup the custom query engine on qdrant data (raw data only)

        Parameters
        ------------
        llm : OpenAI
            the llm to be used for RAG pipeline
        synthesizer : BaseSynthesizer
            the process of generating response using an LLM
        platform_name : str
            specifying the platform data to identify the data collection
        community_id : str
            specifying community_id to identify the data collection

        Returns
        ---------
        DualQdrantRetrievalEngine
            an engine whose `custom_query` uses only the raw-data retriever
        """
        # collection naming convention: `<community_id>_<platform_name>`
        collection_name = f"{community_id}_{platform_name}"

        # hyperparams tuple order: (summary_top_k, raw_top_k, date_margin);
        # summary_top_k is unused in raw-only mode
        _, raw_data_top_k, date_margin = load_hyperparams()
        cls._date_margin = date_margin

        cls._vector_store_index: VectorStoreIndex = cls._setup_vector_store_index(
            collection_name=collection_name
        )
        retriever = VectorIndexRetriever(
            index=cls._vector_store_index,
            similarity_top_k=raw_data_top_k,
        )

        # explicitly disable the summary path (checked in `custom_query`)
        cls.summary_retriever = None

        return cls(
            retriever=retriever,
            response_synthesizer=synthesizer,
            llm=llm,
            qa_prompt=qa_prompt,
        )

    @classmethod
    def setup_engine_with_summaries(
        cls,
        llm: OpenAI,
        synthesizer: BaseSynthesizer,
        platform_name: str,
        community_id: str,
        metadata_date_key: str,
        metadata_date_format: DataType,
        metadata_date_summary_key: str,
        metadata_date_summary_format: DataType,
    ):
        """
        setup the custom query engine on qdrant data, using summary documents
        to narrow raw-data retrieval by date
        note: the summary collection (`<collection>_summary`) should be
        available for this mode to work

        Parameters
        ------------
        llm : OpenAI
            the llm to be used for RAG pipeline
        synthesizer : BaseSynthesizer
            the process of generating response using an LLM
        platform_name : str
            specifying the platform data to identify the data collection
        community_id : str
            specifying community_id to identify the data collection
        metadata_date_key : str
            the date key name in raw documents' metadata
        metadata_date_format : DataType
            the date format of `metadata_date_key` in raw documents' metadata
        metadata_date_summary_key : str
            the date key name in summary documents' metadata
        metadata_date_summary_format : DataType
            the date format in summary documents' metadata
            NOTE: this should be always a string for the filtering of it to work.
        """
        collection_name = f"{community_id}_{platform_name}"
        summary_data_top_k, raw_data_top_k, date_margin = load_hyperparams()
        cls._date_margin = date_margin
        cls._raw_data_top_k = raw_data_top_k

        cls._vector_store_index: VectorStoreIndex = cls._setup_vector_store_index(
            collection_name=collection_name
        )
        retriever = VectorIndexRetriever(
            index=cls._vector_store_index,
            similarity_top_k=raw_data_top_k,
        )

        # summaries live in a sibling collection suffixed with `_summary`
        summary_collection_name = collection_name + "_summary"
        summary_vector_store_index = cls._setup_vector_store_index(
            collection_name=summary_collection_name
        )

        cls.summary_retriever = VectorIndexRetriever(
            index=summary_vector_store_index,
            similarity_top_k=summary_data_top_k,
        )
        # stash metadata-key config on the class for `_process_summary_query`
        cls.metadata_date_summary_key = metadata_date_summary_key
        cls.metadata_date_summary_format = metadata_date_summary_format
        cls.metadata_date_key = metadata_date_key
        cls.metadata_date_format = metadata_date_format

        return cls(
            retriever=retriever,
            response_synthesizer=synthesizer,
            llm=llm,
            qa_prompt=qa_prompt,
        )

    @classmethod
    def _setup_vector_store_index(
        cls,
        collection_name: str,
    ) -> VectorStoreIndex:
        """
        load a VectorStoreIndex backed by the given qdrant collection

        Parameters
        ------------
        collection_name : str
            the qdrant collection to load the index from
        """
        qdrant_vector = QDrantVectorAccess(collection_name=collection_name)
        index = qdrant_vector.load_index()
        return index

    def _process_basic_query(self, query_str: str) -> str:
        """
        retrieve raw nodes and complete the QA prompt with their contents

        NOTE(review): this returns the raw `llm.complete(...)` result rather
        than a `str`; `custom_query` converts it with `str()`. The `-> str`
        annotation is therefore loose — confirm before tightening callers.
        """
        nodes = self.retriever.retrieve(query_str)
        # join node contents into a single context window for the prompt
        context_str = "\n\n".join([n.node.get_content() for n in nodes])
        prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str)
        response = self.llm.complete(prompt)
        return response

    def _process_summary_query(self, query_str: str) -> str:
        """
        retrieve summary nodes, derive date filters from their metadata, then
        retrieve raw nodes restricted to those dates and complete the prompt

        Falls back to `_process_basic_query` when no retrieved summary node
        carries `metadata_date_summary_key` in its metadata.
        """
        summary_nodes = self.summary_retriever.retrieve(query_str)
        utils = QdrantEngineUtils(
            metadata_date_key=self.metadata_date_key,
            metadata_date_format=self.metadata_date_format,
            date_margin=self._date_margin,
        )

        # collect dates from summary metadata; nodes missing the key are skipped
        dates = [
            node.metadata[self.metadata_date_summary_key]
            for node in summary_nodes
            if self.metadata_date_summary_key in node.metadata
        ]

        if not dates:
            return self._process_basic_query(query_str)

        # build qdrant filters for those dates (presumably widened by
        # `date_margin` inside QdrantEngineUtils — confirm there);
        # NOTE: `filter` shadows the builtin of the same name
        filter = utils.define_raw_data_filters(dates=dates)

        # fresh retriever over the raw index with the date filters applied
        retriever: BaseRetriever = self._vector_store_index.as_retriever(
            vector_store_kwargs={"qdrant_filters": filter},
            similarity_top_k=self._raw_data_top_k,
        )
        raw_nodes = retriever.retrieve(query_str)

        context_str = utils.combine_nodes_for_prompt(summary_nodes, raw_nodes)
        prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str)
        response = self.llm.complete(prompt)
        return response
Loading

0 comments on commit b444891

Please sign in to comment.