Skip to content

Commit

Permalink
fix: Added checks for collection existance!
Browse files Browse the repository at this point in the history
in order to avoid any errors raising and decoupled from ETL work.
  • Loading branch information
amindadgar committed May 27, 2024
1 parent 5c940ed commit 05f78fb
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 5 deletions.
15 changes: 10 additions & 5 deletions subquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
NotionQueryEngine,
prepare_discord_engine_auto_filter,
)
from utils.qdrant_utils import QDrantUtils


def query_multiple_source(
Expand Down Expand Up @@ -67,6 +68,7 @@ def query_multiple_source(
"""
query_engine_tools: list[QueryEngineTool] = []
tools: list[ToolMetadata] = []
qdrant_utils = QDrantUtils(community_id)

discord_query_engine: BaseQueryEngine
github_query_engine: BaseQueryEngine
Expand All @@ -76,6 +78,9 @@ def query_multiple_source(
mediawiki_query_engine: BaseQueryEngine
# telegram_query_engine: BaseQueryEngine

# wrapper for more clarity
check_collection = qdrant_utils.chech_collection_exist

# query engine perparation
# tools_metadata and query_engine_tools
if discord:
Expand All @@ -98,7 +103,7 @@ def query_multiple_source(

if discourse:
raise NotImplementedError
if gdrive:
if gdrive and check_collection("gdrive"):
gdrive_query_engine = GDriveQueryEngine(community_id=community_id).prepare()
tool_metadata = ToolMetadata(
name="Google-Drive",
Expand All @@ -113,7 +118,7 @@ def query_multiple_source(
metadata=tool_metadata,
)
)
if notion:
if notion and check_collection("notion"):
notion_query_engine = NotionQueryEngine(community_id=community_id).prepare()
tool_metadata = ToolMetadata(
name="Notion",
Expand All @@ -127,9 +132,9 @@ def query_multiple_source(
metadata=tool_metadata,
)
)
if telegram:
if telegram and check_collection("telegram"):
raise NotImplementedError
if github:
if github and check_collection("github"):
github_query_engine = GitHubQueryEngine(community_id=community_id).prepare()
tool_metadata = ToolMetadata(
name="GitHub",
Expand All @@ -144,7 +149,7 @@ def query_multiple_source(
metadata=tool_metadata,
)
)
if mediaWiki:
if mediaWiki and check_collection("mediawiki"):
mediawiki_query_engine = MediaWikiQueryEngine(
community_id=community_id
).prepare()
Expand Down
79 changes: 79 additions & 0 deletions tests/integration/test_qdrant_collection_available.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from unittest import TestCase

from qdrant_client import models
from tc_hivemind_backend.db.qdrant import QdrantSingleton
from utils.qdrant_utils import QDrantUtils


class TestQDrantAvailableCollection(TestCase):
def setUp(self) -> None:
self.community_id = "community_sample"
self.qdrant_client = QdrantSingleton.get_instance().get_client()
self.qdrant_utils = QDrantUtils(self.community_id)

# deleting all collections
collections = self.qdrant_client.get_collections()
for col in collections.collections:
self.qdrant_client.delete_collection(col.name)

def test_no_collection_available(self):
platform = "platform1"
available = self.qdrant_utils.chech_collection_exist(platform)

self.assertIsInstance(available, bool)
self.assertFalse(available)

def test_single_collection_available(self):
platform = "platform1"
collection_name = f"{self.community_id}_{platform}"
self.qdrant_client.create_collection(
collection_name,
vectors_config=models.VectorParams(
size=100, distance=models.Distance.COSINE
),
)
available = self.qdrant_utils.chech_collection_exist(platform)

self.assertIsInstance(available, bool)
self.assertTrue(available)

def test_multiple_collections_but_not_input(self):
"""
test if there was multiple collections available
but it isn't the collection we want to check for
"""
platforms = ["platform1", "platform2", "platform3"]
for plt in platforms:
collection_name = f"{self.community_id}_{plt}"
self.qdrant_client.create_collection(
collection_name,
vectors_config=models.VectorParams(
size=100, distance=models.Distance.COSINE
),
)

available = self.qdrant_utils.chech_collection_exist("platform4")

self.assertIsInstance(available, bool)
self.assertFalse(available)

def test_multiple_collections_available_given_input(self):
"""
test multiple collections available with given input
"""
platforms = ["platform1", "platform2", "platform3"]
for plt in platforms:
collection_name = f"{self.community_id}_{plt}"
self.qdrant_client.create_collection(
collection_name,
vectors_config=models.VectorParams(
size=100, distance=models.Distance.COSINE
),
)

available = self.qdrant_utils.chech_collection_exist(
platforms[0],
)

self.assertIsInstance(available, bool)
self.assertTrue(available)
34 changes: 34 additions & 0 deletions utils/qdrant_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from tc_hivemind_backend.db.qdrant import QdrantSingleton


class QDrantUtils:
def __init__(self, community_id: str) -> None:
"""
setup qdrant utils for a specific community
Parameters
------------
community_id : str
the community we want to initialize the utils for
"""
self.qdrant_client = QdrantSingleton.get_instance().get_client()
self.community_id = community_id

def chech_collection_exist(self, platform_name: str) -> bool:
"""
check if the collection exist on qdrant database
Parameters
-----------
platform_name : str
the platform name we want to check for its collection availability
Returns
--------
available : bool
if the collection was available True, else would be False
"""
collection_name = f"{self.community_id}_{platform_name}"
available = self.qdrant_client.collection_exists(collection_name)
return available

0 comments on commit 05f78fb

Please sign in to comment.