Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: subquery generator #10

Merged
merged 16 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
CHUNK_SIZE=
COHERE_API_KEY=
D_RETRIEVER_SEARCH=
EMBEDDING_DIM=
K1_RETRIEVER_SEARCH=
K2_RETRIEVER_SEARCH=
MONGODB_HOST=
MONGODB_PASS=
MONGODB_PORT=
MONGODB_USER=
NEO4J_DB=
NEO4J_HOST=
NEO4J_PASSWORD=
NEO4J_PORT=
NEO4J_PROTOCOL=
NEO4J_USER=
OPENAI_API_KEY=
POSTGRES_HOST=
POSTGRES_PASS=
POSTGRES_PORT=
POSTGRES_USER=
RABBIT_HOST=
RABBIT_PASSWORD=
RABBIT_PORT=
RABBIT_USER=
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,5 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

hivemind-bot-env/*
hivemind-bot-env/*
main.ipynb
10 changes: 7 additions & 3 deletions bot/retrievers/summary_retriever_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
the embedding model to use for doing embedding on the query string
default would be CohereEmbedding that we've written
"""
self.index = self._setup_index(table_name, dbname)
self.index = self._setup_index(table_name, dbname, embedding_model)
self.embedding_model = embedding_model

def get_similar_nodes(
Expand Down Expand Up @@ -62,10 +62,14 @@ def get_similar_nodes(

return nodes

def _setup_index(self, table_name: str, dbname: str) -> VectorStoreIndex:
def _setup_index(
self, table_name: str, dbname: str, embedding_model: BaseEmbedding
) -> VectorStoreIndex:
"""
setup the llama_index VectorStoreIndex
"""
pg_vector_access = PGVectorAccess(table_name=table_name, dbname=dbname)
pg_vector_access = PGVectorAccess(
table_name=table_name, dbname=dbname, embed_model=embedding_model
)
index = pg_vector_access.load_index()
return index
139 changes: 15 additions & 124 deletions discord_query.py
Original file line number Diff line number Diff line change
@@ -1,145 +1,36 @@
from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever
from bot.retrievers.process_dates import process_dates
from bot.retrievers.utils.load_hyperparams import load_hyperparams
from llama_index import QueryBundle
from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters
from llama_index.schema import NodeWithScore
from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
from tc_hivemind_backend.pg_vector_access import PGVectorAccess
from utils.query_engine.discord_query_engine import prepare_discord_engine_auto_filter


def query_discord(
community_id: str,
query: str,
thread_names: list[str],
channel_names: list[str],
days: list[str],
similarity_top_k: int | None = None,
) -> str:
) -> tuple[str, list[NodeWithScore]]:
"""
query the discord database using filters given
and give an anwer to the given query using the LLM
query the llm using the query engine

Parameters
------------
guild_id : str
the discord guild data to query
query_engine : BaseQueryEngine
the prepared query engine
query : str
the query (question) of the user
thread_names : list[str]
the given threads to search for
channel_names : list[str]
the given channels to search for
days : list[str]
the given days to search for
similarity_top_k : int | None
the k similar results to use when querying the data
if `None` will load from `.env` file
the string question

Returns
---------
----------
response : str
the LLM response given the query
the LLM response
source_nodes : list[llama_index.schema.NodeWithScore]
the source nodes that helped in answering the question
"""
if similarity_top_k is None:
_, similarity_top_k, _ = load_hyperparams()

table_name = "discord"
dbname = f"community_{community_id}"

pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname)

index = pg_vector.load_index()

thread_filters: list[ExactMatchFilter] = []
channel_filters: list[ExactMatchFilter] = []
day_filters: list[ExactMatchFilter] = []

for channel in channel_names:
channel_updated = channel.replace("'", "''")
channel_filters.append(ExactMatchFilter(key="channel", value=channel_updated))

for thread in thread_names:
thread_updated = thread.replace("'", "''")
thread_filters.append(ExactMatchFilter(key="thread", value=thread_updated))

for day in days:
day_filters.append(ExactMatchFilter(key="date", value=day))

all_filters: list[ExactMatchFilter] = []
all_filters.extend(thread_filters)
all_filters.extend(channel_filters)
all_filters.extend(day_filters)

filters = MetadataFilters(filters=all_filters, condition=FilterCondition.OR)

query_engine = index.as_query_engine(
filters=filters, similarity_top_k=similarity_top_k
query_engine = prepare_discord_engine_auto_filter(
community_id=community_id,
query=query,
)

query_bundle = QueryBundle(
query_str=query, embedding=CohereEmbedding().get_text_embedding(text=query)
)
response = query_engine.query(query_bundle)

return response.response


def query_discord_auto_filter(
community_id: str,
query: str,
similarity_top_k: int | None = None,
d: int | None = None,
) -> str:
"""
get the query results and do the filtering automatically.
By automatically we mean, it would first query the summaries
to get the metadata filters

Parameters
-----------
guild_id : str
the discord guild data to query
query : str
the query (question) of the user
similarity_top_k : int | None
the value for the initial summary search
to get the `k2` count simliar nodes
if `None`, then would read from `.env`
d : int
this would make the secondary search (`query_discord`)
to be done on the `metadata.date - d` to `metadata.date + d`


Returns
---------
response : str
the LLM response given the query
"""
table_name = "discord_summary"
dbname = f"community_{community_id}"

if d is None:
_, _, d = load_hyperparams()
if similarity_top_k is None:
similarity_top_k, _, _ = load_hyperparams()

discord_retriever = ForumBasedSummaryRetriever(table_name=table_name, dbname=dbname)

channels, threads, dates = discord_retriever.retreive_metadata(
query=query,
metadata_group1_key="channel",
metadata_group2_key="thread",
metadata_date_key="date",
similarity_top_k=similarity_top_k,
)

dates_modified = process_dates(list(dates), d)

response = query_discord(
community_id=community_id,
query=query,
thread_names=list(threads),
channel_names=list(channels),
days=dates_modified,
)
return response
return response.response, response.source_nodes
2 changes: 2 additions & 0 deletions docker-compose.test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ services:
- K1_RETRIEVER_SEARCH=20
- K2_RETRIEVER_SEARCH=5
- D_RETRIEVER_SEARCH=7
- COHERE_API_KEY=some_credentials
- OPENAI_API_KEY=some_credentials2
volumes:
- ./coverage:/project/coverage
depends_on:
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ neo4j>=5.14.1, <6.0.0
coverage>=7.3.3, <8.0.0
pytest>=7.4.3, <8.0.0
python-dotenv==1.0.0
tc-hivemind-backend==1.0.0
tc-hivemind-backend==1.1.0
celery>=5.3.6, <6.0.0
guidance
110 changes: 110 additions & 0 deletions subquery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from guidance.models import OpenAI as GuidanceOpenAI
from llama_index import QueryBundle, ServiceContext
from llama_index.core import BaseQueryEngine
from llama_index.query_engine import SubQuestionQueryEngine
from llama_index.question_gen.guidance_generator import GuidanceQuestionGenerator
from llama_index.schema import NodeWithScore
from llama_index.tools import QueryEngineTool, ToolMetadata
from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
from utils.query_engine import prepare_discord_engine_auto_filter


def query_multiple_source(
query: str,
community_id: str,
discord: bool,
discourse: bool,
gdrive: bool,
notion: bool,
telegram: bool,
github: bool,
) -> tuple[str, list[NodeWithScore]]:
"""
query multiple platforms and get an answer from the multiple

Parameters
------------
query : str
the user question
community_id : str
the community id to get their data
discord : bool
if `True` then add the engine to the subquery_generator
discourse : bool
if `True` then add the engine to the subquery_generator
gdrive : bool
if `True` then add the engine to the subquery_generator
notion : bool
if `True` then add the engine to the subquery_generator
telegram : bool
if `True` then add the engine to the subquery_generator
github : bool
if `True` then add the engine to the subquery_generator


Returns
--------
response : str,
the response to the user query from the LLM
using the engines of the given platforms (pltform equal to True)
source_nodes : list[NodeWithScore]
the list of nodes that were source of answering
"""
query_engine_tools: list[QueryEngineTool] = []
tools: list[ToolMetadata] = []

discord_query_engine: BaseQueryEngine
# discourse_query_engine: BaseQueryEngine
# gdrive_query_engine: BaseQueryEngine
# notion_query_engine: BaseQueryEngine
# telegram_query_engine: BaseQueryEngine
# github_query_engine: BaseQueryEngine

# query engine perparation
# tools_metadata and query_engine_tools
if discord:
discord_query_engine = prepare_discord_engine_auto_filter(
community_id,
query,
similarity_top_k=None,
d=None,
)
tool_metadata = ToolMetadata(
name="Discord",
description="Contains messages and summaries of conversations from the Discord platform of the community",
)

tools.append(tool_metadata)
query_engine_tools.append(
QueryEngineTool(
query_engine=discord_query_engine,
metadata=tool_metadata,
)
)

if discourse:
raise NotImplementedError
if gdrive:
raise NotImplementedError
if notion:
raise NotImplementedError
if telegram:
raise NotImplementedError
if github:
raise NotImplementedError

question_gen = GuidanceQuestionGenerator.from_defaults(
guidance_llm=GuidanceOpenAI("text-davinci-003"), verbose=False
)
embed_model = CohereEmbedding()
service_context = ServiceContext.from_defaults(embed_model=embed_model)
s_engine = SubQuestionQueryEngine.from_defaults(
question_gen=question_gen,
query_engine_tools=query_engine_tools,
use_async=False,
service_context=service_context,
)
query_embedding = embed_model.get_text_embedding(text=query)
response = s_engine.query(QueryBundle(query_str=query, embedding=query_embedding))

return response.response, response.source_nodes
Loading