From a8d7fb77f2d476f874da2a701ac4176dc341d727 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 16 May 2024 13:34:37 +0330 Subject: [PATCH 01/11] feat: Added qdrant support for BaseEngine! --- requirements.txt | 2 +- utils/query_engine/base_engine.py | 43 +++++++++++-------------------- 2 files changed, 16 insertions(+), 29 deletions(-) diff --git a/requirements.txt b/requirements.txt index d0cb913..6bab684 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,7 +15,7 @@ neo4j>=5.14.1, <6.0.0 coverage>=7.3.3, <8.0.0 pytest>=7.4.3, <8.0.0 python-dotenv==1.0.0 -tc-hivemind-backend==1.1.5 +tc-hivemind-backend==1.2.0 llama-index-question-gen-guidance==0.1.2 llama-index-vector-stores-postgres==0.1.2 celery>=5.3.6, <6.0.0 diff --git a/utils/query_engine/base_engine.py b/utils/query_engine/base_engine.py index 7f05779..fd5e7cc 100644 --- a/utils/query_engine/base_engine.py +++ b/utils/query_engine/base_engine.py @@ -1,10 +1,10 @@ -from bot.retrievers.custom_retriever import CustomVectorStoreRetriever +from llama_index.core.indices.vector_store.retrievers.retriever import ( + VectorIndexRetriever, +) from bot.retrievers.utils.load_hyperparams import load_hyperparams from llama_index.core import VectorStoreIndex, get_response_synthesizer -from llama_index.core.base.embeddings.base import BaseEmbedding from llama_index.core.query_engine import RetrieverQueryEngine -from tc_hivemind_backend.embeddings import CohereEmbedding -from tc_hivemind_backend.pg_vector_access import PGVectorAccess +from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess class BaseEngine: @@ -24,16 +24,17 @@ def __init__(self, platform_name: str, community_id: str) -> None: """ self.platform_name = platform_name self.community_id = community_id - self.dbname = f"community_{self.community_id}" + self.collection_name = f"{self.community_id}_{platform_name}" def prepare(self, testing=False): - index = self._setup_vector_store_index( + vector_store_index = self._setup_vector_store_index( testing=testing, ) _, similarity_top_k, _ = load_hyperparams() - retriever = CustomVectorStoreRetriever( - index=index, similarity_top_k=similarity_top_k + retriever = VectorIndexRetriever( + index=vector_store_index, + similarity_top_k=similarity_top_k, ) query_engine = RetrieverQueryEngine( retriever=retriever, @@ -54,27 +55,13 @@ def _setup_vector_store_index( testing : bool for testing purposes **kwargs : - table_name : str - to override the default table_name - dbname : str - to override the default database name + collection_name : str + to override the default collection_name """ - table_name = kwargs.get("table_name", self.platform_name) - dbname = kwargs.get("dbname", self.dbname) - - embed_model: BaseEmbedding - if testing: - from llama_index.core import MockEmbedding - - embed_model = MockEmbedding(embed_dim=1024) - else: - embed_model = CohereEmbedding() - - pg_vector = PGVectorAccess( - table_name=table_name, - dbname=dbname, + collection_name = kwargs.get("collection_name", self.collection_name) + qdrant_vector = QDrantVectorAccess( + collection_name=collection_name, testing=testing, - embed_model=embed_model, ) - index = pg_vector.load_index() + index = qdrant_vector.load_index() return index From 66826b44f49c3064302ec0b880966c509fb7cdd7 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 16 May 2024 13:39:53 +0330 Subject: [PATCH 02/11] fix: isort issue! --- utils/query_engine/base_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/query_engine/base_engine.py b/utils/query_engine/base_engine.py index fd5e7cc..547df28 100644 --- a/utils/query_engine/base_engine.py +++ b/utils/query_engine/base_engine.py @@ -1,8 +1,8 @@ +from bot.retrievers.utils.load_hyperparams import load_hyperparams +from llama_index.core import VectorStoreIndex, get_response_synthesizer from llama_index.core.indices.vector_store.retrievers.retriever import ( VectorIndexRetriever, ) -from bot.retrievers.utils.load_hyperparams import load_hyperparams -from llama_index.core import VectorStoreIndex, get_response_synthesizer from llama_index.core.query_engine import RetrieverQueryEngine from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess From f32a46fab37f6c5fc1e8e2604dae870e15089137 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 16 May 2024 16:03:45 +0330 Subject: [PATCH 03/11] feat: Added missing service to docker-compose.test.yaml! --- docker-compose.test.yml | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 11fff88..5ac47cf 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -34,6 +34,9 @@ services: - D_RETRIEVER_SEARCH=7 - COHERE_API_KEY=some_credentials - OPENAI_API_KEY=some_credentials2 + - QDRANT_HOST=localhost + - QDRANT_PORT=6333 + - QDRANT_API_KEY= volumes: - ./coverage:/project/coverage depends_on: @@ -43,6 +46,8 @@ services: condition: service_healthy postgres: condition: service_healthy + qdrant-healthcheck: + condition: service_healthy neo4j: image: "neo4j:5.9.0" environment: @@ -87,3 +92,25 @@ services: timeout: 30s retries: 2 start_period: 40s + qdrant: + image: qdrant/qdrant:v1.9.2 + restart: always + container_name: qdrant + ports: + - 6333:6333 + expose: + - 6333 + volumes: + - ./qdrant_data:/qdrant_data + qdrant-healthcheck: + restart: always + image: curlimages/curl:latest + entrypoint: ["/bin/sh", "-c", "--", "while true; do sleep 30; done;"] + depends_on: + - qdrant + healthcheck: + test: ["CMD", "curl", "-f", "http://qdrant:6333/readyz"] + interval: 10s + timeout: 2s + retries: 5 + From f7d62fdf86c2a15c8f65ccbd9d27743e473a84f2 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 16 May 2024 16:11:33 +0330 Subject: [PATCH 04/11] feat: Added MediaWiki query engine! --- subquery.py | 17 +++++++++++++++-- tests/unit/test_mediawiki_query_engine.py | 15 +++++++++++++++ utils/query_engine/__init__.py | 1 + utils/query_engine/media_wiki.py | 7 +++++++ 4 files changed, 38 insertions(+), 2 deletions(-) create mode 100644 tests/unit/test_mediawiki_query_engine.py create mode 100644 utils/query_engine/media_wiki.py diff --git a/subquery.py b/subquery.py index b5be207..f294051 100644 --- a/subquery.py +++ b/subquery.py @@ -11,6 +11,7 @@ DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL, GDriveQueryEngine, GitHubQueryEngine, + MediaWikiQueryEngine, NotionQueryEngine, prepare_discord_engine_auto_filter, ) @@ -72,7 +73,7 @@ def query_multiple_source( # discourse_query_engine: BaseQueryEngine gdrive_query_engine: BaseQueryEngine notion_query_engine: BaseQueryEngine - # media_wiki_query_engine: BaseQueryEngine + media_wiki_query_engine: BaseQueryEngine # telegram_query_engine: BaseQueryEngine # query engine perparation @@ -144,7 +145,19 @@ def query_multiple_source( ) ) if media_wiki: - raise NotImplementedError + mediawiki_query_engine = MediaWikiQueryEngine( + community_id=community_id + ).prepare() + tool_metadata = ToolMetadata( + name="WikiPedia", + description="Hosts articles about any information on internet", + ) + query_engine_tools.append( + QueryEngineTool( + query_engine=mediawiki_query_engine, + metadata=tool_metadata, + ) + ) embed_model = CohereEmbedding() llm = OpenAI("gpt-3.5-turbo") diff --git a/tests/unit/test_mediawiki_query_engine.py b/tests/unit/test_mediawiki_query_engine.py new file mode 100644 index 0000000..a909bc2 --- /dev/null +++ b/tests/unit/test_mediawiki_query_engine.py @@ -0,0 +1,15 @@ +from unittest import TestCase + +from bot.retrievers.custom_retriever import CustomVectorStoreRetriever +from utils.query_engine import MediaWikiQueryEngine + + +class TestMediaWikiQueryEngine(TestCase): + def setUp(self) -> None: + community_id = "sample_community" + self.notion_query_engine = MediaWikiQueryEngine(community_id) + + def test_prepare_engine(self): + notion_query_engine = self.notion_query_engine.prepare(testing=True) + print(notion_query_engine.__dict__) + self.assertIsInstance(notion_query_engine.retriever, CustomVectorStoreRetriever) diff --git a/utils/query_engine/__init__.py b/utils/query_engine/__init__.py index 0077c43..ed43657 100644 --- a/utils/query_engine/__init__.py +++ b/utils/query_engine/__init__.py @@ -1,6 +1,7 @@ # flake8: noqa from .gdrive import GDriveQueryEngine from .github import GitHubQueryEngine +from .media_wiki import MediaWikiQueryEngine from .notion import NotionQueryEngine from .prepare_discord_query_engine import prepare_discord_engine_auto_filter from .subquery_gen_prompt import DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL diff --git a/utils/query_engine/media_wiki.py b/utils/query_engine/media_wiki.py new file mode 100644 index 0000000..e842e06 --- /dev/null +++ b/utils/query_engine/media_wiki.py @@ -0,0 +1,7 @@ +from utils.query_engine.base_engine import BaseEngine + + +class MediaWikiQueryEngine(BaseEngine): + def __init__(self, community_id: str) -> None: + platform_name = "mediawiki" + super().__init__(platform_name, community_id) From 41093fba7846636058ff7829f6ec35592cb55cf2 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 16 May 2024 16:15:14 +0330 Subject: [PATCH 05/11] fix: update mediaWiki name to match modules on mongodb! --- subquery.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/subquery.py b/subquery.py index f294051..6b43c90 100644 --- a/subquery.py +++ b/subquery.py @@ -26,7 +26,7 @@ def query_multiple_source( notion: bool = False, telegram: bool = False, github: bool = False, - media_wiki: bool = False, + mediaWiki: bool = False, ) -> tuple[str, list[NodeWithScore]]: """ query multiple platforms and get an answer from the multiple @@ -73,7 +73,7 @@ def query_multiple_source( # discourse_query_engine: BaseQueryEngine gdrive_query_engine: BaseQueryEngine notion_query_engine: BaseQueryEngine - media_wiki_query_engine: BaseQueryEngine + mediawiki_query_engine: BaseQueryEngine # telegram_query_engine: BaseQueryEngine # query engine perparation @@ -144,7 +144,7 @@ def query_multiple_source( metadata=tool_metadata, ) ) - if media_wiki: + if mediaWiki: mediawiki_query_engine = MediaWikiQueryEngine( community_id=community_id ).prepare() From 9a9247dd42dcbb322e3a8aa4d1218ba9d2e72282 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 20 May 2024 11:49:13 +0330 Subject: [PATCH 06/11] fix: qdrant ip in docker-compose shouldn't be localhost! --- docker-compose.test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 5ac47cf..e1a1fb7 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -34,7 +34,7 @@ services: - D_RETRIEVER_SEARCH=7 - COHERE_API_KEY=some_credentials - OPENAI_API_KEY=some_credentials2 - - QDRANT_HOST=localhost + - QDRANT_HOST=qdrant - QDRANT_PORT=6333 - QDRANT_API_KEY= volumes: From b1f15f4e360cf227776cc0b6ece533753a34c5f6 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 20 May 2024 11:59:09 +0330 Subject: [PATCH 07/11] fix: update retriever type assertion! We were using the default `VectorIndexRetriever` for the new pipelines and it was wrong to assert those with CustomVectorStoreRetriever. --- tests/unit/test_gdrive_query_engine.py | 4 ++-- tests/unit/test_github_query_engine.py | 4 ++-- tests/unit/test_mediawiki_query_engine.py | 4 ++-- tests/unit/test_notion_query_engine.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/unit/test_gdrive_query_engine.py b/tests/unit/test_gdrive_query_engine.py index 706d211..7ab4e65 100644 --- a/tests/unit/test_gdrive_query_engine.py +++ b/tests/unit/test_gdrive_query_engine.py @@ -1,6 +1,6 @@ from unittest import TestCase -from bot.retrievers.custom_retriever import CustomVectorStoreRetriever +from llama_index.core.indices.vector_store.retrievers.retriever import VectorIndexRetriever from utils.query_engine import GDriveQueryEngine @@ -12,4 +12,4 @@ def setUp(self) -> None: def test_prepare_engine(self): gdrive_query_engine = self.gdrive_query_engine.prepare(testing=True) print(gdrive_query_engine.__dict__) - self.assertIsInstance(gdrive_query_engine.retriever, CustomVectorStoreRetriever) + self.assertIsInstance(gdrive_query_engine.retriever, VectorIndexRetriever) diff --git a/tests/unit/test_github_query_engine.py b/tests/unit/test_github_query_engine.py index 1361398..b3b657f 100644 --- a/tests/unit/test_github_query_engine.py +++ b/tests/unit/test_github_query_engine.py @@ -1,6 +1,6 @@ from unittest import TestCase -from bot.retrievers.custom_retriever import CustomVectorStoreRetriever +from llama_index.core.indices.vector_store.retrievers.retriever import VectorIndexRetriever from utils.query_engine import GitHubQueryEngine @@ -12,4 +12,4 @@ def setUp(self) -> None: def test_prepare_engine(self): github_query_engine = self.github_query_engine.prepare(testing=True) print(github_query_engine.__dict__) - self.assertIsInstance(github_query_engine.retriever, CustomVectorStoreRetriever) + self.assertIsInstance(github_query_engine.retriever, VectorIndexRetriever) diff --git a/tests/unit/test_mediawiki_query_engine.py b/tests/unit/test_mediawiki_query_engine.py index a909bc2..303aa63 100644 --- a/tests/unit/test_mediawiki_query_engine.py +++ b/tests/unit/test_mediawiki_query_engine.py @@ -1,6 +1,6 @@ from unittest import TestCase -from bot.retrievers.custom_retriever import CustomVectorStoreRetriever +from llama_index.core.indices.vector_store.retrievers.retriever import VectorIndexRetriever from utils.query_engine import MediaWikiQueryEngine @@ -12,4 +12,4 @@ def setUp(self) -> None: def test_prepare_engine(self): notion_query_engine = self.notion_query_engine.prepare(testing=True) print(notion_query_engine.__dict__) - self.assertIsInstance(notion_query_engine.retriever, CustomVectorStoreRetriever) + self.assertIsInstance(notion_query_engine.retriever, VectorIndexRetriever) diff --git a/tests/unit/test_notion_query_engine.py b/tests/unit/test_notion_query_engine.py index caaf903..57f2b96 100644 --- a/tests/unit/test_notion_query_engine.py +++ b/tests/unit/test_notion_query_engine.py @@ -1,6 +1,6 @@ from unittest import TestCase -from bot.retrievers.custom_retriever import CustomVectorStoreRetriever +from llama_index.core.indices.vector_store.retrievers.retriever import VectorIndexRetriever from utils.query_engine import NotionQueryEngine @@ -12,4 +12,4 @@ def setUp(self) -> None: def test_prepare_engine(self): notion_query_engine = self.notion_query_engine.prepare(testing=True) print(notion_query_engine.__dict__) - self.assertIsInstance(notion_query_engine.retriever, CustomVectorStoreRetriever) + self.assertIsInstance(notion_query_engine.retriever, VectorIndexRetriever) From 2533f362871038ef65bc82eeda718d4be18090dd Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 20 May 2024 12:17:56 +0330 Subject: [PATCH 08/11] fix: We still needed the PGVector engine for discord and discourse! --- tests/unit/test_base_engine.py | 4 +- utils/query_engine/base_pg_engine.py | 80 +++++++++++++++++++ .../{base_engine.py => base_qdrant_engine.py} | 4 +- utils/query_engine/gdrive.py | 4 +- utils/query_engine/github.py | 4 +- .../level_based_platform_query_engine.py | 6 +- utils/query_engine/media_wiki.py | 4 +- utils/query_engine/notion.py | 4 +- 8 files changed, 95 insertions(+), 15 deletions(-) create mode 100644 utils/query_engine/base_pg_engine.py rename utils/query_engine/{base_engine.py => base_qdrant_engine.py} (95%) diff --git a/tests/unit/test_base_engine.py b/tests/unit/test_base_engine.py index 43ef52b..14dd60e 100644 --- a/tests/unit/test_base_engine.py +++ b/tests/unit/test_base_engine.py @@ -1,6 +1,6 @@ from unittest import TestCase -from utils.query_engine.base_engine import BaseEngine +from utils.query_engine.base_qdrant_engine import BaseQdrantEngine class TestBaseEngine(TestCase): @@ -11,7 +11,7 @@ def test_setup_vector_store_index(self): """ platform_table_name = "test_table" community_id = "123456" - base_engine = BaseEngine( + base_engine = BaseQdrantEngine( platform_name=platform_table_name, community_id=community_id, ) diff --git a/utils/query_engine/base_pg_engine.py b/utils/query_engine/base_pg_engine.py new file mode 100644 index 0000000..1f2e415 --- /dev/null +++ b/utils/query_engine/base_pg_engine.py @@ -0,0 +1,80 @@ +from bot.retrievers.custom_retriever import CustomVectorStoreRetriever +from bot.retrievers.utils.load_hyperparams import load_hyperparams +from llama_index.core import VectorStoreIndex, get_response_synthesizer +from llama_index.core.base.embeddings.base import BaseEmbedding +from llama_index.core.query_engine import RetrieverQueryEngine +from tc_hivemind_backend.embeddings import CohereEmbedding +from tc_hivemind_backend.pg_vector_access import PGVectorAccess + + +class BasePGEngine: + def __init__(self, platform_name: str, community_id: str) -> None: + """ + initialize the pg vector db engine to query the database related to a community + and the table related to the platform + + Parameters + ----------- + platform_name : str + the table representative of data for a specific platform + community_id : str + the database for a community + normally the community database is saved as + `community_{community_id}` + """ + self.platform_name = platform_name + self.community_id = community_id + self.dbname = f"community_{self.community_id}" + + def prepare(self, testing=False): + index = self._setup_vector_store_index( + testing=testing, + ) + _, similarity_top_k, _ = load_hyperparams() + + retriever = CustomVectorStoreRetriever( + index=index, similarity_top_k=similarity_top_k + ) + query_engine = RetrieverQueryEngine( + retriever=retriever, + response_synthesizer=get_response_synthesizer(), + ) + return query_engine + + def _setup_vector_store_index( + self, + testing: bool = False, + **kwargs, + ) -> VectorStoreIndex: + """ + prepare the vector_store for querying data + + Parameters + ------------ + testing : bool + for testing purposes + **kwargs : + table_name : str + to override the default table_name + dbname : str + to override the default database name + """ + table_name = kwargs.get("table_name", self.platform_name) + dbname = kwargs.get("dbname", self.dbname) + + embed_model: BaseEmbedding + if testing: + from llama_index.core import MockEmbedding + + embed_model = MockEmbedding(embed_dim=1024) + else: + embed_model = CohereEmbedding() + + pg_vector = PGVectorAccess( + table_name=table_name, + dbname=dbname, + testing=testing, + embed_model=embed_model, + ) + index = pg_vector.load_index() + return index \ No newline at end of file diff --git a/utils/query_engine/base_engine.py b/utils/query_engine/base_qdrant_engine.py similarity index 95% rename from utils/query_engine/base_engine.py rename to utils/query_engine/base_qdrant_engine.py index 547df28..2c89896 100644 --- a/utils/query_engine/base_engine.py +++ b/utils/query_engine/base_qdrant_engine.py @@ -7,10 +7,10 @@ from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess -class BaseEngine: +class BaseQdrantEngine: def __init__(self, platform_name: str, community_id: str) -> None: """ - initialize the engine to query the database related to a community + initialize the qdrant db engine to query the database related to a community and the table related to the platform Parameters diff --git a/utils/query_engine/gdrive.py b/utils/query_engine/gdrive.py index 7be4bc8..114984e 100644 --- a/utils/query_engine/gdrive.py +++ b/utils/query_engine/gdrive.py @@ -1,7 +1,7 @@ -from utils.query_engine.base_engine import BaseEngine +from utils.query_engine.base_qdrant_engine import BaseQdrantEngine -class GDriveQueryEngine(BaseEngine): +class GDriveQueryEngine(BaseQdrantEngine): def __init__(self, community_id: str) -> None: platform_name = "gdrive" super().__init__(platform_name, community_id) diff --git a/utils/query_engine/github.py b/utils/query_engine/github.py index ad91ead..beaefa1 100644 --- a/utils/query_engine/github.py +++ b/utils/query_engine/github.py @@ -1,7 +1,7 @@ -from utils.query_engine.base_engine import BaseEngine +from utils.query_engine.base_qdrant_engine import BaseQdrantEngine -class GitHubQueryEngine(BaseEngine): +class GitHubQueryEngine(BaseQdrantEngine): def __init__(self, community_id: str) -> None: platform_name = "github" super().__init__(platform_name, community_id) diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py index 6c0edf7..da51ac8 100644 --- a/utils/query_engine/level_based_platform_query_engine.py +++ b/utils/query_engine/level_based_platform_query_engine.py @@ -13,7 +13,7 @@ from llama_index.core.retrievers import BaseRetriever from llama_index.core.schema import NodeWithScore from llama_index.llms.openai import OpenAI -from utils.query_engine.base_engine import BaseEngine +from utils.query_engine.base_pg_engine import BasePGEngine from utils.query_engine.level_based_platforms_util import LevelBasedPlatformUtils qa_prompt = PromptTemplate( @@ -116,7 +116,7 @@ def prepare_platform_engine( ) llm = kwargs.get("llm", OpenAI("gpt-4")) qa_prompt_ = kwargs.get("qa_prompt", qa_prompt) - base_engine = BaseEngine(platform_table_name, community_id) + base_engine = BasePGEngine(platform_table_name, community_id) index: VectorStoreIndex = kwargs.get( "index_raw", base_engine._setup_vector_store_index( @@ -198,7 +198,7 @@ def prepare_engine_auto_filter( dbname = f"community_{community_id}" summary_similarity_top_k, _, d = load_hyperparams() - base_engine = BaseEngine( + base_engine = BasePGEngine( platform_table_name + "_summary", community_id, ) diff --git a/utils/query_engine/media_wiki.py b/utils/query_engine/media_wiki.py index e842e06..1bbc8bb 100644 --- a/utils/query_engine/media_wiki.py +++ b/utils/query_engine/media_wiki.py @@ -1,7 +1,7 @@ -from utils.query_engine.base_engine import BaseEngine +from utils.query_engine.base_qdrant_engine import BaseQdrantEngine -class MediaWikiQueryEngine(BaseEngine): +class MediaWikiQueryEngine(BaseQdrantEngine): def __init__(self, community_id: str) -> None: platform_name = "mediawiki" super().__init__(platform_name, community_id) diff --git a/utils/query_engine/notion.py b/utils/query_engine/notion.py index c32aebd..5e1a3c2 100644 --- a/utils/query_engine/notion.py +++ b/utils/query_engine/notion.py @@ -1,7 +1,7 @@ -from utils.query_engine.base_engine import BaseEngine +from utils.query_engine.base_qdrant_engine import BaseQdrantEngine -class NotionQueryEngine(BaseEngine): +class NotionQueryEngine(BaseQdrantEngine): def __init__(self, community_id: str) -> None: platform_name = "notion" super().__init__(platform_name, community_id) From 83aa68338f243d5eea55857567e1b94eabae76df Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 20 May 2024 12:19:36 +0330 Subject: [PATCH 09/11] fix: getting back the test case for pgvector engine! getting it back from main branch but just a couple of renames --- tests/unit/test_base_pg_engine.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 tests/unit/test_base_pg_engine.py diff --git a/tests/unit/test_base_pg_engine.py b/tests/unit/test_base_pg_engine.py new file mode 100644 index 0000000..15c2c7b --- /dev/null +++ b/tests/unit/test_base_pg_engine.py @@ -0,0 +1,24 @@ +from unittest import TestCase + +from utils.query_engine.base_pg_engine import BasePGEngine + + +class TestPGBaseEngine(TestCase): + def test_setup_vector_store_index(self): + """ + Tests that _setup_vector_store_index creates a PGVectorAccess object + and calls its load_index method. + """ + platform_table_name = "test_table" + community_id = "123456" + base_engine = BasePGEngine( + platform_name=platform_table_name, + community_id=community_id, + ) + base_engine = base_engine._setup_vector_store_index( + testing=True, + ) + + expected_dbname = f"community_{community_id}" + self.assertIn(expected_dbname, base_engine.vector_store.connection_string) + self.assertEqual(base_engine.vector_store.table_name, platform_table_name) \ No newline at end of file From 04581b9c73178010115a85679a8e0e9993e6402d Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 20 May 2024 12:28:25 +0330 Subject: [PATCH 10/11] fix: qdrant properties! qdrant engine had different property from pg engine we had earlier! --- .../{test_base_engine.py => test_base_qdrant_engine.py} | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) rename tests/unit/{test_base_engine.py => test_base_qdrant_engine.py} (70%) diff --git a/tests/unit/test_base_engine.py b/tests/unit/test_base_qdrant_engine.py similarity index 70% rename from tests/unit/test_base_engine.py rename to tests/unit/test_base_qdrant_engine.py index 14dd60e..b3bca46 100644 --- a/tests/unit/test_base_engine.py +++ b/tests/unit/test_base_qdrant_engine.py @@ -3,7 +3,7 @@ from utils.query_engine.base_qdrant_engine import BaseQdrantEngine -class TestBaseEngine(TestCase): +class TestBaseQdrantEngine(TestCase): def test_setup_vector_store_index(self): """ Tests that _setup_vector_store_index creates a PGVectorAccess object @@ -19,6 +19,7 @@ def test_setup_vector_store_index(self): testing=True, ) - expected_dbname = f"community_{community_id}" - self.assertIn(expected_dbname, base_engine.vector_store.connection_string) - self.assertEqual(base_engine.vector_store.table_name, platform_table_name) + expected_collection_name = f"{community_id}_{platform_table_name}" + self.assertEqual( + base_engine.vector_store.collection_name, expected_collection_name + ) From 62e3a2284e99144cd052afaa9afb59dae133d756 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 20 May 2024 12:35:11 +0330 Subject: [PATCH 11/11] fix: linter issues! --- tests/unit/test_base_pg_engine.py | 2 +- tests/unit/test_gdrive_query_engine.py | 4 +++- tests/unit/test_github_query_engine.py | 4 +++- tests/unit/test_mediawiki_query_engine.py | 4 +++- tests/unit/test_notion_query_engine.py | 4 +++- utils/query_engine/base_pg_engine.py | 2 +- 6 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/unit/test_base_pg_engine.py b/tests/unit/test_base_pg_engine.py index 15c2c7b..1e2c815 100644 --- a/tests/unit/test_base_pg_engine.py +++ b/tests/unit/test_base_pg_engine.py @@ -21,4 +21,4 @@ def test_setup_vector_store_index(self): expected_dbname = f"community_{community_id}" self.assertIn(expected_dbname, base_engine.vector_store.connection_string) - self.assertEqual(base_engine.vector_store.table_name, platform_table_name) \ No newline at end of file + self.assertEqual(base_engine.vector_store.table_name, platform_table_name) diff --git a/tests/unit/test_gdrive_query_engine.py b/tests/unit/test_gdrive_query_engine.py index 7ab4e65..3aac35b 100644 --- a/tests/unit/test_gdrive_query_engine.py +++ b/tests/unit/test_gdrive_query_engine.py @@ -1,6 +1,8 @@ from unittest import TestCase -from llama_index.core.indices.vector_store.retrievers.retriever import VectorIndexRetriever +from llama_index.core.indices.vector_store.retrievers.retriever import ( + VectorIndexRetriever, +) from utils.query_engine import GDriveQueryEngine diff --git a/tests/unit/test_github_query_engine.py b/tests/unit/test_github_query_engine.py index b3b657f..3d77abb 100644 --- a/tests/unit/test_github_query_engine.py +++ b/tests/unit/test_github_query_engine.py @@ -1,6 +1,8 @@ from unittest import TestCase -from llama_index.core.indices.vector_store.retrievers.retriever import VectorIndexRetriever +from llama_index.core.indices.vector_store.retrievers.retriever import ( + VectorIndexRetriever, +) from utils.query_engine import GitHubQueryEngine diff --git a/tests/unit/test_mediawiki_query_engine.py b/tests/unit/test_mediawiki_query_engine.py index 303aa63..235a0ae 100644 --- a/tests/unit/test_mediawiki_query_engine.py +++ b/tests/unit/test_mediawiki_query_engine.py @@ -1,6 +1,8 @@ from unittest import TestCase -from llama_index.core.indices.vector_store.retrievers.retriever import VectorIndexRetriever +from llama_index.core.indices.vector_store.retrievers.retriever import ( + VectorIndexRetriever, +) from utils.query_engine import MediaWikiQueryEngine diff --git a/tests/unit/test_notion_query_engine.py b/tests/unit/test_notion_query_engine.py index 57f2b96..05587a6 100644 --- a/tests/unit/test_notion_query_engine.py +++ b/tests/unit/test_notion_query_engine.py @@ -1,6 +1,8 @@ from unittest import TestCase -from llama_index.core.indices.vector_store.retrievers.retriever import VectorIndexRetriever +from llama_index.core.indices.vector_store.retrievers.retriever import ( + VectorIndexRetriever, +) from utils.query_engine import NotionQueryEngine diff --git a/utils/query_engine/base_pg_engine.py b/utils/query_engine/base_pg_engine.py index 1f2e415..7b2d16a 100644 --- a/utils/query_engine/base_pg_engine.py +++ b/utils/query_engine/base_pg_engine.py @@ -77,4 +77,4 @@ def _setup_vector_store_index( embed_model=embed_model, ) index = pg_vector.load_index() - return index \ No newline at end of file + return index