Skip to content

Commit

Permalink
fix: We still needed the PGVector engine for discord and discourse!
Browse files Browse the repository at this point in the history
amindadgar committed May 20, 2024

Verified

This commit was signed with the committer’s verified signature.
amindadgar Mohammad Amin Dadgar
1 parent b1f15f4 commit 2533f36
Showing 8 changed files with 95 additions and 15 deletions.
4 changes: 2 additions & 2 deletions tests/unit/test_base_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from unittest import TestCase

from utils.query_engine.base_engine import BaseEngine
from utils.query_engine.base_qdrant_engine import BaseQdrantEngine


class TestBaseEngine(TestCase):
@@ -11,7 +11,7 @@ def test_setup_vector_store_index(self):
"""
platform_table_name = "test_table"
community_id = "123456"
base_engine = BaseEngine(
base_engine = BaseQdrantEngine(
platform_name=platform_table_name,
community_id=community_id,
)
80 changes: 80 additions & 0 deletions utils/query_engine/base_pg_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
from bot.retrievers.utils.load_hyperparams import load_hyperparams
from llama_index.core import VectorStoreIndex, get_response_synthesizer
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.query_engine import RetrieverQueryEngine
from tc_hivemind_backend.embeddings import CohereEmbedding
from tc_hivemind_backend.pg_vector_access import PGVectorAccess


class BasePGEngine:
    """
    query-engine builder on top of the PGVector store of a community

    one instance is bound to one platform (the table) of
    one community (the database)
    """

    def __init__(self, platform_name: str, community_id: str) -> None:
        """
        set up the pg vector db engine for a platform of a community

        Parameters
        -----------
        platform_name : str
            the platform name, used as the default table name
        community_id : str
            the community id, the default database name is built from it
            as `community_{community_id}`
        """
        self.platform_name = platform_name
        self.community_id = community_id
        self.dbname = f"community_{self.community_id}"

    def prepare(self, testing=False):
        """
        build a retriever query engine over the vector store index

        Parameters
        ------------
        testing : bool
            passed as is to `_setup_vector_store_index`

        Returns
        ---------
        query_engine : RetrieverQueryEngine
            the engine made of a `CustomVectorStoreRetriever` and
            the response synthesizer from `get_response_synthesizer()`
        """
        vector_index = self._setup_vector_store_index(testing=testing)

        # just the second hyperparameter is needed in here
        _, top_k, _ = load_hyperparams()

        return RetrieverQueryEngine(
            retriever=CustomVectorStoreRetriever(
                index=vector_index,
                similarity_top_k=top_k,
            ),
            response_synthesizer=get_response_synthesizer(),
        )

    def _setup_vector_store_index(
        self,
        testing: bool = False,
        **kwargs,
    ) -> VectorStoreIndex:
        """
        load the vector store index to query data from

        Parameters
        ------------
        testing : bool
            if True, a `MockEmbedding` with dimension 1024 is used
            instead of `CohereEmbedding`
            the flag is also handed over to `PGVectorAccess`
        **kwargs :
            table_name : str
                to override the default table_name (the platform name)
            dbname : str
                to override the default database name

        Returns
        ---------
        index : VectorStoreIndex
            what `PGVectorAccess.load_index()` returns
        """
        embed_model: BaseEmbedding
        if not testing:
            embed_model = CohereEmbedding()
        else:
            # imported lazily since it is just needed for the testing path
            from llama_index.core import MockEmbedding

            embed_model = MockEmbedding(embed_dim=1024)

        # keys other than `table_name` and `dbname` in kwargs are not used
        return PGVectorAccess(
            table_name=kwargs.get("table_name", self.platform_name),
            dbname=kwargs.get("dbname", self.dbname),
            testing=testing,
            embed_model=embed_model,
        ).load_index()
Original file line number Diff line number Diff line change
@@ -7,10 +7,10 @@
from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess


class BaseEngine:
class BaseQdrantEngine:
def __init__(self, platform_name: str, community_id: str) -> None:
"""
initialize the engine to query the database related to a community
initialize the qdrant db engine to query the database related to a community
and the table related to the platform
Parameters
4 changes: 2 additions & 2 deletions utils/query_engine/gdrive.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from utils.query_engine.base_engine import BaseEngine
from utils.query_engine.base_qdrant_engine import BaseQdrantEngine


class GDriveQueryEngine(BaseEngine):
class GDriveQueryEngine(BaseQdrantEngine):
    """qdrant-based query engine fixed to the `gdrive` platform"""

    def __init__(self, community_id: str) -> None:
        """
        Parameters
        -----------
        community_id : str
            the community id, handed over to the base engine
        """
        # the platform name is constant for this engine
        super().__init__("gdrive", community_id)
4 changes: 2 additions & 2 deletions utils/query_engine/github.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from utils.query_engine.base_engine import BaseEngine
from utils.query_engine.base_qdrant_engine import BaseQdrantEngine


class GitHubQueryEngine(BaseEngine):
class GitHubQueryEngine(BaseQdrantEngine):
    """qdrant-based query engine fixed to the `github` platform"""

    def __init__(self, community_id: str) -> None:
        """
        Parameters
        -----------
        community_id : str
            the community id, handed over to the base engine
        """
        # the platform name is constant for this engine
        super().__init__("github", community_id)
6 changes: 3 additions & 3 deletions utils/query_engine/level_based_platform_query_engine.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore
from llama_index.llms.openai import OpenAI
from utils.query_engine.base_engine import BaseEngine
from utils.query_engine.base_pg_engine import BasePGEngine
from utils.query_engine.level_based_platforms_util import LevelBasedPlatformUtils

qa_prompt = PromptTemplate(
@@ -116,7 +116,7 @@ def prepare_platform_engine(
)
llm = kwargs.get("llm", OpenAI("gpt-4"))
qa_prompt_ = kwargs.get("qa_prompt", qa_prompt)
base_engine = BaseEngine(platform_table_name, community_id)
base_engine = BasePGEngine(platform_table_name, community_id)
index: VectorStoreIndex = kwargs.get(
"index_raw",
base_engine._setup_vector_store_index(
@@ -198,7 +198,7 @@ def prepare_engine_auto_filter(
dbname = f"community_{community_id}"
summary_similarity_top_k, _, d = load_hyperparams()

base_engine = BaseEngine(
base_engine = BasePGEngine(
platform_table_name + "_summary",
community_id,
)
4 changes: 2 additions & 2 deletions utils/query_engine/media_wiki.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from utils.query_engine.base_engine import BaseEngine
from utils.query_engine.base_qdrant_engine import BaseQdrantEngine


class MediaWikiQueryEngine(BaseEngine):
class MediaWikiQueryEngine(BaseQdrantEngine):
    """qdrant-based query engine fixed to the `mediawiki` platform"""

    def __init__(self, community_id: str) -> None:
        """
        Parameters
        -----------
        community_id : str
            the community id, handed over to the base engine
        """
        # the platform name is constant for this engine
        super().__init__("mediawiki", community_id)
4 changes: 2 additions & 2 deletions utils/query_engine/notion.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from utils.query_engine.base_engine import BaseEngine
from utils.query_engine.base_qdrant_engine import BaseQdrantEngine


class NotionQueryEngine(BaseEngine):
class NotionQueryEngine(BaseQdrantEngine):
    """qdrant-based query engine fixed to the `notion` platform"""

    def __init__(self, community_id: str) -> None:
        """
        Parameters
        -----------
        community_id : str
            the community id, handed over to the base engine
        """
        # the platform name is constant for this engine
        super().__init__("notion", community_id)

0 comments on commit 2533f36

Please sign in to comment.