Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/90 qdrant summary and raw engine #99

Merged
merged 10 commits into from
Nov 14, 2024
8 changes: 8 additions & 0 deletions schema/type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from enum import Enum


class DataType(Enum):
INTEGER = "INTEGER"
STRING = "STRING"
BOOLEAN = "BOOLEAN"
FLOAT = "FLOAT"
21 changes: 11 additions & 10 deletions subquery.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from guidance.models import OpenAIChat
from llama_index.core import QueryBundle, Settings
from llama_index.core.base.base_query_engine import BaseQueryEngine
from llama_index.core.query_engine import SubQuestionQueryEngine
from llama_index.core.schema import NodeWithScore
from llama_index.core.tools import QueryEngineTool, ToolMetadata
Expand All @@ -14,6 +13,7 @@
GitHubQueryEngine,
MediaWikiQueryEngine,
NotionQueryEngine,
TelegramDualQueryEngine,
TelegramQueryEngine,
prepare_discord_engine_auto_filter,
)
Expand Down Expand Up @@ -71,14 +71,6 @@ def query_multiple_source(
tools: list[ToolMetadata] = []
qdrant_utils = QDrantUtils(community_id)

discord_query_engine: BaseQueryEngine
github_query_engine: BaseQueryEngine
# discourse_query_engine: BaseQueryEngine
google_query_engine: BaseQueryEngine
notion_query_engine: BaseQueryEngine
mediawiki_query_engine: BaseQueryEngine
# telegram_query_engine: BaseQueryEngine

# wrapper for more clarity
check_collection = qdrant_utils.check_collection_exist

Expand Down Expand Up @@ -134,7 +126,16 @@ def query_multiple_source(
)
)
if telegram and check_collection("telegram"):
telegram_query_engine = TelegramQueryEngine(community_id=community_id).prepare()
# checking if the summaries was available
if check_collection("telegram_summary"):
telegram_query_engine = TelegramDualQueryEngine(
community_id=community_id
).prepare()
else:
telegram_query_engine = TelegramQueryEngine(
community_id=community_id
).prepare()

amindadgar marked this conversation as resolved.
Show resolved Hide resolved
tool_metadata = ToolMetadata(
name="Telegram",
description=(
Expand Down
25 changes: 0 additions & 25 deletions tests/unit/test_base_qdrant_engine.py

This file was deleted.

3 changes: 2 additions & 1 deletion utils/query_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# flake8: noqa
from .dual_qdrant_retrieval_engine import DualQdrantRetrievalEngine
from .gdrive import GDriveQueryEngine
from .github import GitHubQueryEngine
from .media_wiki import MediaWikiQueryEngine
from .notion import NotionQueryEngine
from .prepare_discord_query_engine import prepare_discord_engine_auto_filter
from .subquery_gen_prompt import DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL
from .telegram import TelegramQueryEngine
from .telegram import TelegramDualQueryEngine, TelegramQueryEngine
56 changes: 11 additions & 45 deletions utils/query_engine/base_qdrant_engine.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from bot.retrievers.utils.load_hyperparams import load_hyperparams
from llama_index.core import VectorStoreIndex, get_response_synthesizer
from llama_index.core.indices.vector_store.retrievers.retriever import (
VectorIndexRetriever,
)
from llama_index.core.query_engine import RetrieverQueryEngine
from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess
from llama_index.core import Settings, get_response_synthesizer
from llama_index.core.query_engine import BaseQueryEngine

from .dual_qdrant_retrieval_engine import DualQdrantRetrievalEngine


class BaseQdrantEngine:
Expand All @@ -24,44 +21,13 @@ def __init__(self, platform_name: str, community_id: str) -> None:
"""
self.platform_name = platform_name
self.community_id = community_id
self.collection_name = f"{self.community_id}_{platform_name}"

def prepare(self, testing=False):
vector_store_index = self._setup_vector_store_index(
testing=testing,
def prepare(self, testing=False) -> BaseQueryEngine:
engine = DualQdrantRetrievalEngine.setup_engine(
llm=Settings.llm,
synthesizer=get_response_synthesizer(),
platform_name=self.platform_name,
community_id=self.community_id,
)
_, similarity_top_k, _ = load_hyperparams()

retriever = VectorIndexRetriever(
index=vector_store_index,
similarity_top_k=similarity_top_k,
)
query_engine = RetrieverQueryEngine(
retriever=retriever,
response_synthesizer=get_response_synthesizer(),
)
return query_engine

def _setup_vector_store_index(
self,
testing: bool = False,
**kwargs,
) -> VectorStoreIndex:
"""
prepare the vector_store for querying data

Parameters
------------
testing : bool
for testing purposes
**kwargs :
collection_name : str
to override the default collection_name
"""
collection_name = kwargs.get("collection_name", self.collection_name)
qdrant_vector = QDrantVectorAccess(
collection_name=collection_name,
testing=testing,
)
index = qdrant_vector.load_index()
return index
return engine
210 changes: 210 additions & 0 deletions utils/query_engine/dual_qdrant_retrieval_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
from bot.retrievers.utils.load_hyperparams import load_hyperparams
from llama_index.core import PromptTemplate, VectorStoreIndex
from llama_index.core.indices.vector_store.retrievers.retriever import (
VectorIndexRetriever,
)
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.response_synthesizers import BaseSynthesizer
from llama_index.core.retrievers import BaseRetriever
from llama_index.llms.openai import OpenAI
from schema.type import DataType
from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess
from utils.query_engine.qdrant_query_engine_utils import QdrantEngineUtils

qa_prompt = PromptTemplate(
"Context information is below.\n"
"---------------------\n"
"{context_str}\n"
"---------------------\n"
"Given the context information and not prior knowledge, "
"answer the query.\n"
"Query: {query_str}\n"
"Answer: "
)


class DualQdrantRetrievalEngine(CustomQueryEngine):
"""RAG String Query Engine."""

retriever: BaseRetriever
response_synthesizer: BaseSynthesizer
llm: OpenAI
qa_prompt: PromptTemplate

def custom_query(self, query_str: str):
if self.summary_retriever is None:
response = self._process_basic_query(query_str)
else:
response = self._process_summary_query(query_str)
return str(response)

@classmethod
def setup_engine(
cls,
llm: OpenAI,
synthesizer: BaseSynthesizer,
platform_name: str,
community_id: str,
):
"""
setup the custom query engine on qdrant data

Parameters
------------
llm : OpenAI
the llm to be used for RAG pipeline
synthesizer : BaseSynthesizer
the process of generating response using an LLM
qa_prompt : PromptTemplate
the prompt template to be filled and passed to an LLM
platform_name : str
specifying the platform data to identify the data collection
community_id : str
specifying community_id to identify the data collection
"""
collection_name = f"{community_id}_{platform_name}"

_, raw_data_top_k, date_margin = load_hyperparams()
cls._date_margin = date_margin

cls._vector_store_index: VectorStoreIndex = cls._setup_vector_store_index(
collection_name=collection_name
)
retriever = VectorIndexRetriever(
index=cls._vector_store_index,
similarity_top_k=raw_data_top_k,
)

cls.summary_retriever = None

return cls(
retriever=retriever,
response_synthesizer=synthesizer,
llm=llm,
qa_prompt=qa_prompt,
)

amindadgar marked this conversation as resolved.
Show resolved Hide resolved
@classmethod
def setup_engine_with_summaries(
cls,
llm: OpenAI,
synthesizer: BaseSynthesizer,
platform_name: str,
community_id: str,
metadata_date_key: str,
metadata_date_format: DataType,
metadata_date_summary_key: str,
metadata_date_summary_format: DataType,
):
"""
setup the custom query engine on qdrant data

Parameters
------------
use_summary : bool
whether to use the summary data or not
note: the summary data should be available before
for this option to be enabled
llm : OpenAI
the llm to be used for RAG pipeline
synthesizer : BaseSynthesizer
the process of generating response using an LLM
platform_name : str
specifying the platform data to identify the data collection
community_id : str
specifying community_id to identify the data collection
metadata_date_summary_key : str | None
the date key name in summary documents' metadata
In case of `use_summary` equal to be true this shuold be passed
metadata_date_summary_format : DataType | None
the date format in metadata
In case of `use_summary` equal to be true this shuold be passed
NOTE: this should be always a string for the filtering of it to work.
"""
collection_name = f"{community_id}_{platform_name}"
summary_data_top_k, raw_data_top_k, date_margin = load_hyperparams()
cls._date_margin = date_margin
cls._raw_data_top_k = raw_data_top_k

cls._vector_store_index: VectorStoreIndex = cls._setup_vector_store_index(
collection_name=collection_name
)
retriever = VectorIndexRetriever(
index=cls._vector_store_index,
similarity_top_k=raw_data_top_k,
)

summary_collection_name = collection_name + "_summary"
summary_vector_store_index = cls._setup_vector_store_index(
collection_name=summary_collection_name
)

cls.summary_retriever = VectorIndexRetriever(
index=summary_vector_store_index,
similarity_top_k=summary_data_top_k,
)
cls.metadata_date_summary_key = metadata_date_summary_key
cls.metadata_date_summary_format = metadata_date_summary_format
cls.metadata_date_key = metadata_date_key
cls.metadata_date_format = metadata_date_format

return cls(
retriever=retriever,
response_synthesizer=synthesizer,
llm=llm,
qa_prompt=qa_prompt,
)

@classmethod
def _setup_vector_store_index(
cls,
collection_name: str,
) -> VectorStoreIndex:
"""
prepare the vector_store for querying data

Parameters
------------
collection_name : str
to override the default collection_name
"""
qdrant_vector = QDrantVectorAccess(collection_name=collection_name)
index = qdrant_vector.load_index()
return index
amindadgar marked this conversation as resolved.
Show resolved Hide resolved

def _process_basic_query(self, query_str: str) -> str:
nodes = self.retriever.retrieve(query_str)
context_str = "\n\n".join([n.node.get_content() for n in nodes])
prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str)
response = self.llm.complete(prompt)
return response

def _process_summary_query(self, query_str: str) -> str:
summary_nodes = self.summary_retriever.retrieve(query_str)
utils = QdrantEngineUtils(
metadata_date_key=self.metadata_date_key,
metadata_date_format=self.metadata_date_format,
date_margin=self._date_margin,
)

dates = [
node.metadata[self.metadata_date_summary_key]
for node in summary_nodes
if self.metadata_date_summary_key in node.metadata
]

if not dates:
return self._process_basic_query(query_str)

filter = utils.define_raw_data_filters(dates=dates)

retriever: BaseRetriever = self._vector_store_index.as_retriever(
vector_store_kwargs={"qdrant_filters": filter},
similarity_top_k=self._raw_data_top_k,
)
raw_nodes = retriever.retrieve(query_str)

context_str = utils.combine_nodes_for_prompt(summary_nodes, raw_nodes)
prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str)
response = self.llm.complete(prompt)
return response
Loading
Loading