From f3b0bffe6a89c69b7a3d429337b2abf204509267 Mon Sep 17 00:00:00 2001 From: Anush Date: Fri, 13 Sep 2024 18:01:17 +0530 Subject: [PATCH] test: Integration tests for upserting with shard keys (#32) --- src/test/python/conftest.py | 26 ++++++-- src/test/python/test_qdrant_ingest.py | 89 ++++++++++++++++++++++----- 2 files changed, 96 insertions(+), 19 deletions(-) diff --git a/src/test/python/conftest.py b/src/test/python/conftest.py index d590f1b..9d05637 100644 --- a/src/test/python/conftest.py +++ b/src/test/python/conftest.py @@ -1,16 +1,19 @@ import pytest -from testcontainers.qdrant import QdrantContainer +from testcontainers.core.container import DockerContainer from qdrant_client import QdrantClient, models import uuid from pyspark.sql import SparkSession from typing import NamedTuple from uuid import uuid4 +from testcontainers.core.waiting_utils import wait_for_logs QDRANT_GRPC_PORT = 6334 QDRANT_EMBEDDING_DIM = 6 QDRANT_DISTANCE = models.Distance.COSINE QDRANT_API_KEY = uuid4().hex +STRING_SHARD_KEY = "string_shard_key" +INTEGER_SHARD_KEY = 876 class Qdrant(NamedTuple): @@ -20,7 +23,15 @@ class Qdrant(NamedTuple): client: QdrantClient -qdrant_container = QdrantContainer(image="qdrant/qdrant:latest", api_key=QDRANT_API_KEY) +qdrant_container = ( + ( + DockerContainer(image="qdrant/qdrant:latest") + .with_env("QDRANT__SERVICE__API_KEY", QDRANT_API_KEY) + .with_env("QDRANT__CLUSTER__ENABLED", "true") + ) + .with_command("./qdrant --uri http://qdrant_node_1:6335") + .with_exposed_ports(QDRANT_GRPC_PORT) +) # Reference: https://gist.github.com/dizzythinks/f3bb37fd8ab1484bfec79d39ad8a92d3 @@ -38,6 +49,7 @@ def get_pom_version(): @pytest.fixture(scope="module", autouse=True) def setup_container(request): qdrant_container.start() + wait_for_logs(qdrant_container, "Qdrant gRPC listening on 6334") def remove_container(): qdrant_container.stop() @@ -92,14 +104,20 @@ def qdrant(): "multi": models.VectorParams( size=QDRANT_EMBEDDING_DIM, distance=QDRANT_DISTANCE, - multivector_config=models.MultiVectorConfig(comparator=models.MultiVectorComparator.MAX_SIM) - ) + multivector_config=models.MultiVectorConfig( + comparator=models.MultiVectorComparator.MAX_SIM + ), + ), }, sparse_vectors_config={ "sparse": models.SparseVectorParams(), "another_sparse": models.SparseVectorParams(), }, + sharding_method=models.ShardingMethod.CUSTOM, ) + + client.create_shard_key(collection_name, STRING_SHARD_KEY) + client.create_shard_key(collection_name, INTEGER_SHARD_KEY) yield Qdrant( url=f"http://{host}:{grpc_port}", diff --git a/src/test/python/test_qdrant_ingest.py b/src/test/python/test_qdrant_ingest.py index f9d2e72..d06f8b2 100644 --- a/src/test/python/test_qdrant_ingest.py +++ b/src/test/python/test_qdrant_ingest.py @@ -2,7 +2,7 @@ from pyspark.sql import SparkSession from .schema import schema -from .conftest import Qdrant +from .conftest import Qdrant, STRING_SHARD_KEY, INTEGER_SHARD_KEY current_directory = os.path.dirname(__file__) input_file_path = os.path.join(current_directory, "..", "resources", "users.json") @@ -20,12 +20,16 @@ def test_upsert_unnamed_vectors(qdrant: Qdrant, spark_session: SparkSession): "embedding_field": "dense_vector", "api_key": qdrant.api_key, "schema": df.schema.json(), + "shard_key_selector": STRING_SHARD_KEY, } df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save() assert ( - qdrant.client.count(qdrant.collection_name).count == df.count() + qdrant.client.count( + qdrant.collection_name, shard_key_selector=STRING_SHARD_KEY + ).count + == df.count() ), "Uploaded points count is not equal to the dataframe count" @@ -41,12 +45,16 @@ def test_upsert_named_vectors(qdrant: Qdrant, spark_session: SparkSession): "vector_name": "dense", "schema": df.schema.json(), "api_key": qdrant.api_key, + "shard_key_selector": STRING_SHARD_KEY, } df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save() assert ( - qdrant.client.count(qdrant.collection_name).count == df.count() + qdrant.client.count( + qdrant.collection_name, shard_key_selector=STRING_SHARD_KEY + ).count + == df.count() ), "Uploaded points count is not equal to the dataframe count" @@ -65,12 +73,16 @@ def test_upsert_multiple_named_dense_vectors( "vector_names": "dense,another_dense", "schema": df.schema.json(), "api_key": qdrant.api_key, + "shard_key_selector": STRING_SHARD_KEY, } df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save() assert ( - qdrant.client.count(qdrant.collection_name).count == df.count() + qdrant.client.count( + qdrant.collection_name, shard_key_selector=STRING_SHARD_KEY + ).count + == df.count() ), "Uploaded points count is not equal to the dataframe count" @@ -88,12 +100,16 @@ def test_upsert_sparse_vectors(qdrant: Qdrant, spark_session: SparkSession): "sparse_vector_names": "sparse", "schema": df.schema.json(), "api_key": qdrant.api_key, + "shard_key_selector": STRING_SHARD_KEY, } df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save() assert ( - qdrant.client.count(qdrant.collection_name).count == df.count() + qdrant.client.count( + qdrant.collection_name, shard_key_selector=STRING_SHARD_KEY + ).count + == df.count() ), "Uploaded points count is not equal to the dataframe count" @@ -111,12 +127,16 @@ def test_upsert_multiple_sparse_vectors(qdrant: Qdrant, spark_session: SparkSess "sparse_vector_names": "sparse,another_sparse", "schema": df.schema.json(), "api_key": qdrant.api_key, + "shard_key_selector": STRING_SHARD_KEY, } df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save() assert ( - qdrant.client.count(qdrant.collection_name).count == df.count() + qdrant.client.count( + qdrant.collection_name, shard_key_selector=STRING_SHARD_KEY + ).count + == df.count() ), "Uploaded points count is not equal to the dataframe count" @@ -136,12 +156,16 @@ def test_upsert_sparse_named_dense_vectors(qdrant: Qdrant, spark_session: SparkS "sparse_vector_names": "sparse", "schema": df.schema.json(), "api_key": qdrant.api_key, + "shard_key_selector": STRING_SHARD_KEY, } df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save() assert ( - qdrant.client.count(qdrant.collection_name).count == df.count() + qdrant.client.count( + qdrant.collection_name, shard_key_selector=STRING_SHARD_KEY + ).count + == df.count() ), "Uploaded points count is not equal to the dataframe count" @@ -162,12 +186,16 @@ def test_upsert_sparse_unnamed_dense_vectors( "sparse_vector_names": "sparse", "schema": df.schema.json(), "api_key": qdrant.api_key, + "shard_key_selector": INTEGER_SHARD_KEY, } df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save() assert ( - qdrant.client.count(qdrant.collection_name).count == df.count() + qdrant.client.count( + qdrant.collection_name, shard_key_selector=INTEGER_SHARD_KEY + ).count + == df.count() ), "Uploaded points count is not equal to the dataframe count" @@ -189,17 +217,20 @@ def test_upsert_multiple_sparse_dense_vectors( "sparse_vector_names": "sparse,another_sparse", "schema": df.schema.json(), "api_key": qdrant.api_key, + "shard_key_selector": INTEGER_SHARD_KEY, } df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save() assert ( - qdrant.client.count(qdrant.collection_name).count == df.count() + qdrant.client.count( + qdrant.collection_name, shard_key_selector=INTEGER_SHARD_KEY + ).count + == df.count() ), "Uploaded points count is not equal to the dataframe count" -def test_upsert_multi_vector( - qdrant: Qdrant, spark_session: SparkSession -): + +def test_upsert_multi_vector(qdrant: Qdrant, spark_session: SparkSession): df = ( spark_session.read.schema(schema) .option("multiline", "true") @@ -212,12 +243,16 @@ def test_upsert_multi_vector( "multi_vector_names": "multi", "schema": df.schema.json(), "api_key": qdrant.api_key, + "shard_key_selector": INTEGER_SHARD_KEY, } df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save() assert ( - qdrant.client.count(qdrant.collection_name).count == df.count() + qdrant.client.count( + qdrant.collection_name, shard_key_selector=INTEGER_SHARD_KEY + ).count + == df.count() ), "Uploaded points count is not equal to the dataframe count" @@ -233,11 +268,15 @@ def test_upsert_without_vectors(qdrant: Qdrant, spark_session: SparkSession): "collection_name": qdrant.collection_name, "schema": df.schema.json(), "api_key": qdrant.api_key, + "shard_key_selector": INTEGER_SHARD_KEY, } df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save() assert ( - qdrant.client.count(qdrant.collection_name).count == df.count() + qdrant.client.count( + qdrant.collection_name, shard_key_selector=INTEGER_SHARD_KEY + ).count + == df.count() ), "Uploaded points count is not equal to the dataframe count" @@ -256,7 +295,27 @@ def test_custom_id_field(qdrant: Qdrant, spark_session: SparkSession): "id_field": "id", "schema": df.schema.json(), "api_key": qdrant.api_key, + "shard_key_selector": f"{STRING_SHARD_KEY},{INTEGER_SHARD_KEY}", } df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save() - assert len(qdrant.client.retrieve(qdrant.collection_name, [1, 2, 3, 15, 18])) == 5 + assert ( + len( + qdrant.client.retrieve( + qdrant.collection_name, + [1, 2, 3, 15, 18], + shard_key_selector=INTEGER_SHARD_KEY, + ) + ) + == 5 + ) + assert ( + len( + qdrant.client.retrieve( + qdrant.collection_name, + [1, 2, 3, 15, 18], + shard_key_selector=STRING_SHARD_KEY, + ) + ) + == 5 + )