Skip to content

Commit

Permalink
test: Integration tests for upserting with shard keys (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
Anush008 authored Sep 13, 2024
1 parent 24b09ef commit f3b0bff
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 19 deletions.
26 changes: 22 additions & 4 deletions src/test/python/conftest.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import pytest
from testcontainers.qdrant import QdrantContainer
from testcontainers.core.container import DockerContainer
from qdrant_client import QdrantClient, models
import uuid
from pyspark.sql import SparkSession
from typing import NamedTuple
from uuid import uuid4
from testcontainers.core.waiting_utils import wait_for_logs


# Shared test configuration for the Qdrant container and collections.
QDRANT_GRPC_PORT = 6334  # gRPC port the tests wait for and connect to (see wait_for_logs below)
QDRANT_EMBEDDING_DIM = 6  # dimensionality of the dense test vectors
QDRANT_DISTANCE = models.Distance.COSINE  # distance metric for all dense vector configs
QDRANT_API_KEY = uuid4().hex  # fresh random API key per test session
# Custom shard keys created on the test collection; one string-typed and one
# integer-typed, exercised separately by the ingestion tests.
STRING_SHARD_KEY = "string_shard_key"
INTEGER_SHARD_KEY = 876


class Qdrant(NamedTuple):
Expand All @@ -20,7 +23,15 @@ class Qdrant(NamedTuple):
client: QdrantClient


qdrant_container = QdrantContainer(image="qdrant/qdrant:latest", api_key=QDRANT_API_KEY)
# Qdrant container with the API key set and cluster mode enabled, so custom
# sharding (create_shard_key) is available to the tests. The gRPC port is
# exposed for the client/Spark connector to reach.
qdrant_container = (
    DockerContainer(image="qdrant/qdrant:latest")
    .with_env("QDRANT__SERVICE__API_KEY", QDRANT_API_KEY)
    .with_env("QDRANT__CLUSTER__ENABLED", "true")
    .with_command("./qdrant --uri http://qdrant_node_1:6335")
    .with_exposed_ports(QDRANT_GRPC_PORT)
)


# Reference: https://gist.github.com/dizzythinks/f3bb37fd8ab1484bfec79d39ad8a92d3
Expand All @@ -38,6 +49,7 @@ def get_pom_version():
@pytest.fixture(scope="module", autouse=True)
def setup_container(request):
qdrant_container.start()
wait_for_logs(qdrant_container, "Qdrant gRPC listening on 6334")

def remove_container():
qdrant_container.stop()
Expand Down Expand Up @@ -92,14 +104,20 @@ def qdrant():
"multi": models.VectorParams(
size=QDRANT_EMBEDDING_DIM,
distance=QDRANT_DISTANCE,
multivector_config=models.MultiVectorConfig(comparator=models.MultiVectorComparator.MAX_SIM)
)
multivector_config=models.MultiVectorConfig(
comparator=models.MultiVectorComparator.MAX_SIM
),
),
},
sparse_vectors_config={
"sparse": models.SparseVectorParams(),
"another_sparse": models.SparseVectorParams(),
},
sharding_method=models.ShardingMethod.CUSTOM,
)

client.create_shard_key(collection_name, STRING_SHARD_KEY)
client.create_shard_key(collection_name, INTEGER_SHARD_KEY)

yield Qdrant(
url=f"http://{host}:{grpc_port}",
Expand Down
89 changes: 74 additions & 15 deletions src/test/python/test_qdrant_ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pyspark.sql import SparkSession

from .schema import schema
from .conftest import Qdrant
from .conftest import Qdrant, STRING_SHARD_KEY, INTEGER_SHARD_KEY

current_directory = os.path.dirname(__file__)
input_file_path = os.path.join(current_directory, "..", "resources", "users.json")
Expand All @@ -20,12 +20,16 @@ def test_upsert_unnamed_vectors(qdrant: Qdrant, spark_session: SparkSession):
"embedding_field": "dense_vector",
"api_key": qdrant.api_key,
"schema": df.schema.json(),
"shard_key_selector": STRING_SHARD_KEY,
}

df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save()

assert (
qdrant.client.count(qdrant.collection_name).count == df.count()
qdrant.client.count(
qdrant.collection_name, shard_key_selector=STRING_SHARD_KEY
).count
== df.count()
), "Uploaded points count is not equal to the dataframe count"


Expand All @@ -41,12 +45,16 @@ def test_upsert_named_vectors(qdrant: Qdrant, spark_session: SparkSession):
"vector_name": "dense",
"schema": df.schema.json(),
"api_key": qdrant.api_key,
"shard_key_selector": STRING_SHARD_KEY,
}

df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save()

assert (
qdrant.client.count(qdrant.collection_name).count == df.count()
qdrant.client.count(
qdrant.collection_name, shard_key_selector=STRING_SHARD_KEY
).count
== df.count()
), "Uploaded points count is not equal to the dataframe count"


Expand All @@ -65,12 +73,16 @@ def test_upsert_multiple_named_dense_vectors(
"vector_names": "dense,another_dense",
"schema": df.schema.json(),
"api_key": qdrant.api_key,
"shard_key_selector": STRING_SHARD_KEY,
}

df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save()

assert (
qdrant.client.count(qdrant.collection_name).count == df.count()
qdrant.client.count(
qdrant.collection_name, shard_key_selector=STRING_SHARD_KEY
).count
== df.count()
), "Uploaded points count is not equal to the dataframe count"


Expand All @@ -88,12 +100,16 @@ def test_upsert_sparse_vectors(qdrant: Qdrant, spark_session: SparkSession):
"sparse_vector_names": "sparse",
"schema": df.schema.json(),
"api_key": qdrant.api_key,
"shard_key_selector": STRING_SHARD_KEY,
}

df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save()

assert (
qdrant.client.count(qdrant.collection_name).count == df.count()
qdrant.client.count(
qdrant.collection_name, shard_key_selector=STRING_SHARD_KEY
).count
== df.count()
), "Uploaded points count is not equal to the dataframe count"


Expand All @@ -111,12 +127,16 @@ def test_upsert_multiple_sparse_vectors(qdrant: Qdrant, spark_session: SparkSess
"sparse_vector_names": "sparse,another_sparse",
"schema": df.schema.json(),
"api_key": qdrant.api_key,
"shard_key_selector": STRING_SHARD_KEY,
}

df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save()

assert (
qdrant.client.count(qdrant.collection_name).count == df.count()
qdrant.client.count(
qdrant.collection_name, shard_key_selector=STRING_SHARD_KEY
).count
== df.count()
), "Uploaded points count is not equal to the dataframe count"


Expand All @@ -136,12 +156,16 @@ def test_upsert_sparse_named_dense_vectors(qdrant: Qdrant, spark_session: SparkS
"sparse_vector_names": "sparse",
"schema": df.schema.json(),
"api_key": qdrant.api_key,
"shard_key_selector": STRING_SHARD_KEY,
}

df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save()

assert (
qdrant.client.count(qdrant.collection_name).count == df.count()
qdrant.client.count(
qdrant.collection_name, shard_key_selector=STRING_SHARD_KEY
).count
== df.count()
), "Uploaded points count is not equal to the dataframe count"


Expand All @@ -162,12 +186,16 @@ def test_upsert_sparse_unnamed_dense_vectors(
"sparse_vector_names": "sparse",
"schema": df.schema.json(),
"api_key": qdrant.api_key,
"shard_key_selector": INTEGER_SHARD_KEY,
}

df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save()

assert (
qdrant.client.count(qdrant.collection_name).count == df.count()
qdrant.client.count(
qdrant.collection_name, shard_key_selector=INTEGER_SHARD_KEY
).count
== df.count()
), "Uploaded points count is not equal to the dataframe count"


Expand All @@ -189,17 +217,20 @@ def test_upsert_multiple_sparse_dense_vectors(
"sparse_vector_names": "sparse,another_sparse",
"schema": df.schema.json(),
"api_key": qdrant.api_key,
"shard_key_selector": INTEGER_SHARD_KEY,
}

df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save()

assert (
qdrant.client.count(qdrant.collection_name).count == df.count()
qdrant.client.count(
qdrant.collection_name, shard_key_selector=INTEGER_SHARD_KEY
).count
== df.count()
), "Uploaded points count is not equal to the dataframe count"

def test_upsert_multi_vector(
qdrant: Qdrant, spark_session: SparkSession
):

def test_upsert_multi_vector(qdrant: Qdrant, spark_session: SparkSession):
df = (
spark_session.read.schema(schema)
.option("multiline", "true")
Expand All @@ -212,12 +243,16 @@ def test_upsert_multi_vector(
"multi_vector_names": "multi",
"schema": df.schema.json(),
"api_key": qdrant.api_key,
"shard_key_selector": INTEGER_SHARD_KEY,
}

df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save()

assert (
qdrant.client.count(qdrant.collection_name).count == df.count()
qdrant.client.count(
qdrant.collection_name, shard_key_selector=INTEGER_SHARD_KEY
).count
== df.count()
), "Uploaded points count is not equal to the dataframe count"


Expand All @@ -233,11 +268,15 @@ def test_upsert_without_vectors(qdrant: Qdrant, spark_session: SparkSession):
"collection_name": qdrant.collection_name,
"schema": df.schema.json(),
"api_key": qdrant.api_key,
"shard_key_selector": INTEGER_SHARD_KEY,
}
df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save()

assert (
qdrant.client.count(qdrant.collection_name).count == df.count()
qdrant.client.count(
qdrant.collection_name, shard_key_selector=INTEGER_SHARD_KEY
).count
== df.count()
), "Uploaded points count is not equal to the dataframe count"


Expand All @@ -256,7 +295,27 @@ def test_custom_id_field(qdrant: Qdrant, spark_session: SparkSession):
"id_field": "id",
"schema": df.schema.json(),
"api_key": qdrant.api_key,
"shard_key_selector": f"{STRING_SHARD_KEY},{INTEGER_SHARD_KEY}",
}
df.write.format("io.qdrant.spark.Qdrant").options(**opts).mode("append").save()

assert len(qdrant.client.retrieve(qdrant.collection_name, [1, 2, 3, 15, 18])) == 5
assert (
len(
qdrant.client.retrieve(
qdrant.collection_name,
[1, 2, 3, 15, 18],
shard_key_selector=INTEGER_SHARD_KEY,
)
)
== 5
)
assert (
len(
qdrant.client.retrieve(
qdrant.collection_name,
[1, 2, 3, 15, 18],
shard_key_selector=STRING_SHARD_KEY,
)
)
== 5
)

0 comments on commit f3b0bff

Please sign in to comment.