From dac2b9d35d17ba868e7f84ff57815133086c1fe0 Mon Sep 17 00:00:00 2001 From: Kumaran Rajendhiran Date: Thu, 23 Jan 2025 13:57:20 +0000 Subject: [PATCH] Use skip_on_missing_imports to mark tests in test/agentchat/contrib/vectordb files --- .../contrib/vectordb/test_chromadb.py | 7 ++-- .../contrib/vectordb/test_mongodb.py | 32 +++++++++++-------- .../contrib/vectordb/test_pgvectordb.py | 9 ++---- .../agentchat/contrib/vectordb/test_qdrant.py | 10 ++---- 4 files changed, 25 insertions(+), 33 deletions(-) diff --git a/test/agentchat/contrib/vectordb/test_chromadb.py b/test/agentchat/contrib/vectordb/test_chromadb.py index c45f7f1175..48c67a7931 100644 --- a/test/agentchat/contrib/vectordb/test_chromadb.py +++ b/test/agentchat/contrib/vectordb/test_chromadb.py @@ -10,7 +10,7 @@ import pytest from autogen.agentchat.contrib.vectordb.chromadb import ChromaVectorDB -from autogen.import_utils import optional_import_block +from autogen.import_utils import optional_import_block, skip_on_missing_imports sys.path.append(os.path.join(os.path.dirname(__file__), "..")) @@ -20,10 +20,7 @@ import sentence_transformers # noqa: F401 -skip = not result.is_successful - - -@pytest.mark.skipif(skip, reason="dependency is not installed") +@skip_on_missing_imports(["chromadb", "sentence_transformers"], "retrievechat") def test_chromadb(): # test create collection db = ChromaVectorDB(path=".db") diff --git a/test/agentchat/contrib/vectordb/test_mongodb.py b/test/agentchat/contrib/vectordb/test_mongodb.py index 76877e6109..0bf9379ac1 100644 --- a/test/agentchat/contrib/vectordb/test_mongodb.py +++ b/test/agentchat/contrib/vectordb/test_mongodb.py @@ -13,22 +13,12 @@ from autogen.agentchat.contrib.vectordb.base import Document from autogen.agentchat.contrib.vectordb.mongodb import MongoDBAtlasVectorDB -from autogen.import_utils import optional_import_block +from autogen.import_utils import optional_import_block, skip_on_missing_imports with optional_import_block() as result: - import pymongo # noqa: F401 - import sentence_transformers # noqa: F401 - - -if not result.is_successful: - # To display warning in pyproject.toml [tool.pytest.ini_options] set log_cli = true - logger = logging.getLogger(__name__) - logger.warning(f"skipping {__name__}. It requires one to pip install pymongo or the extra [retrievechat-mongodb]") - pytest.skip(allow_module_level=True) - -from pymongo import MongoClient -from pymongo.collection import Collection -from pymongo.errors import OperationFailure + from pymongo import MongoClient + from pymongo.collection import Collection + from pymongo.errors import OperationFailure logger = logging.getLogger(__name__) @@ -143,6 +133,7 @@ def collection_name(): return f"{MONGODB_COLLECTION}_{collection_id}" +@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb") def test_create_collection(db, collection_name): """Def create_collection(collection_name: str, overwrite: bool = False) -> Collection @@ -172,6 +163,7 @@ def test_create_collection(db, collection_name): db.create_collection(collection_name=collection_name, overwrite=False, get_or_create=False) +@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb") def test_get_collection(db, collection_name): with pytest.raises(ValueError): db.get_collection() @@ -185,6 +177,7 @@ def test_get_collection(db, collection_name): assert collection_got.name == db.active_collection.name +@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb") def test_delete_collection(db, collection_name): assert collection_name not in db.list_collections() collection = db.create_collection(collection_name) @@ -193,6 +186,7 @@ def test_delete_collection(db, collection_name): assert collection_name not in db.list_collections() +@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb") def test_insert_docs(db, collection_name, example_documents): # Test that there's an active collection with pytest.raises(ValueError) as exc: @@ -218,6 +212,7 @@ def test_insert_docs(db, collection_name, example_documents): assert len(found[0]["embedding"]) == 384 +@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb") def test_update_docs(db_with_indexed_clxn, example_documents): db, collection = db_with_indexed_clxn # Use update_docs to insert new documents @@ -253,6 +248,7 @@ def test_update_docs(db_with_indexed_clxn, example_documents): assert collection.find_one({"_id": new_id}) is None +@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb") def test_delete_docs(db_with_indexed_clxn, example_documents): db, clxn = db_with_indexed_clxn # Insert example documents @@ -263,6 +259,7 @@ def test_delete_docs(db_with_indexed_clxn, example_documents): assert {2, "2"} == {doc["_id"] for doc in clxn.find({})} +@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb") def test_get_docs_by_ids(db_with_indexed_clxn, example_documents): db, clxn = db_with_indexed_clxn # Insert example documents @@ -288,11 +285,13 @@ def test_get_docs_by_ids(db_with_indexed_clxn, example_documents): assert len(docs) == 4 +@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb") def test_retrieve_docs_empty(db_with_indexed_clxn): db, clxn = db_with_indexed_clxn assert db.retrieve_docs(queries=["Cats"], collection_name=clxn.name, n_results=2) == [] +@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb") def test_retrieve_docs_populated_db_empty_query(db_with_indexed_clxn, example_documents): db, clxn = db_with_indexed_clxn db.insert_docs(example_documents, collection_name=clxn.name) @@ -301,6 +300,7 @@ def test_retrieve_docs_populated_db_empty_query(db_with_indexed_clxn, example_do assert results == [] +@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb") def test_retrieve_docs(db_with_indexed_clxn, example_documents): """Begin testing Atlas Vector Search NOTE: Indexing may take some time, so we must be patient on the first query. @@ -324,6 +324,7 @@ def results_ready(): assert all(["embedding" not in doc[0] for doc in results[0]]) +@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb") def test_retrieve_docs_with_embedding(db_with_indexed_clxn, example_documents): """Begin testing Atlas Vector Search NOTE: Indexing may take some time, so we must be patient on the first query. @@ -347,6 +348,7 @@ def results_ready(): assert all(["embedding" in doc[0] for doc in results[0]]) +@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb") def test_retrieve_docs_multiple_queries(db_with_indexed_clxn, example_documents): db, clxn = db_with_indexed_clxn # Insert example documents @@ -369,6 +371,7 @@ def results_ready(): assert {doc[0]["id"] for doc in results[1]} == {"1", "2"} +@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb") def test_retrieve_docs_with_threshold(db_with_indexed_clxn, example_documents): db, clxn = db_with_indexed_clxn # Insert example documents @@ -390,6 +393,7 @@ def results_ready(): assert all([doc[1] >= 0.7 for doc in results[0]]) +@skip_on_missing_imports(["pymongo", "sentence_transformers"], "retrievechat-mongodb") def test_wait_until_document_ready(collection_name, example_documents): database = MongoClient(MONGODB_URI)[MONGODB_DATABASE] _empty_collections_and_delete_indexes(database, [collection_name], wait=True) diff --git a/test/agentchat/contrib/vectordb/test_pgvectordb.py b/test/agentchat/contrib/vectordb/test_pgvectordb.py index 47b6e6e50f..efff6359fd 100644 --- a/test/agentchat/contrib/vectordb/test_pgvectordb.py +++ b/test/agentchat/contrib/vectordb/test_pgvectordb.py @@ -11,25 +11,22 @@ import pytest from autogen.agentchat.contrib.vectordb.pgvectordb import PGVectorDB -from autogen.import_utils import optional_import_block +from autogen.import_utils import optional_import_block, skip_on_missing_imports from ....conftest import reason with optional_import_block() as result: - import pgvector # noqa: F401 import psycopg - import sentence_transformers # noqa: F401 -skip = not result.is_successful - reason = "do not run on MacOS or windows OR dependency is not installed OR " + reason @pytest.mark.skipif( - sys.platform in ["darwin", "win32"] or skip, + sys.platform in ["darwin", "win32"], reason=reason, ) +@skip_on_missing_imports(["pgvector", "psycopg", "sentence_transformers"], "retrievechat-pgvector") def test_pgvector(): # test db config db_config = { diff --git a/test/agentchat/contrib/vectordb/test_qdrant.py b/test/agentchat/contrib/vectordb/test_qdrant.py index 20870edddc..5e0cb5dec8 100644 --- a/test/agentchat/contrib/vectordb/test_qdrant.py +++ b/test/agentchat/contrib/vectordb/test_qdrant.py @@ -8,22 +8,16 @@ import sys import uuid -import pytest - from autogen.agentchat.contrib.vectordb.qdrant import QdrantVectorDB -from autogen.import_utils import optional_import_block +from autogen.import_utils import optional_import_block, skip_on_missing_imports sys.path.append(os.path.join(os.path.dirname(__file__), "..")) with optional_import_block() as result: - from fastembed import TextEmbedding # noqa: F401 from qdrant_client import QdrantClient -skip = not result.is_successful - - -@pytest.mark.skipif(skip, reason="dependency is not installed") +@skip_on_missing_imports(["fastembed", "qdrant_client"], "retrievechat-qdrant") def test_qdrant(): # test create collection client = QdrantClient(location=":memory:")