Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: feat: incompatible versions warning on init #858

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions qdrant_client/async_qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(
Union[Callable[[], str], Callable[[], Awaitable[str]]]
] = None,
cloud_inference: bool = False,
check_version: Optional[bool] = None,
**kwargs: Any,
):
self._inference_inspector = Inspector()
Expand Down Expand Up @@ -132,6 +133,7 @@ def __init__(
host=host,
grpc_options=grpc_options,
auth_token_provider=auth_token_provider,
check_version=check_version,
**kwargs,
)
if isinstance(self._client, AsyncQdrantLocal) and cloud_inference:
Expand Down
27 changes: 26 additions & 1 deletion qdrant_client/async_qdrant_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
# ```
#
# ****** WARNING: THIS FILE IS AUTOGENERATED ******

import asyncio
import importlib.metadata
import logging
import math
import warnings
Expand All @@ -28,6 +29,7 @@
import httpx
import numpy as np
from grpc import Compression
from packaging import version
from urllib3.util import Url, parse_url
from qdrant_client import grpc as grpc
from qdrant_client._pydantic_compat import construct
Expand Down Expand Up @@ -66,6 +68,7 @@ def __init__(
auth_token_provider: Optional[
Union[Callable[[], str], Callable[[], Awaitable[str]]]
] = None,
check_version: Optional[bool] = None,
**kwargs: Any,
):
super().__init__(**kwargs)
Expand Down Expand Up @@ -150,6 +153,28 @@ def __init__(
self._grpc_snapshots_client: Optional[grpc.SnapshotsStub] = None
self._grpc_root_client: Optional[grpc.QdrantStub] = None
self._closed: bool = False
if check_version:
client_version = importlib.metadata.version("qdrant-client")
loop = asyncio.get_event_loop()
server_version = loop.run_until_complete(self.info()).version
is_ok = self._check_versions(client_version, server_version)
if not is_ok:
warnings.warn(
f"Found Qdrant server version `{server_version}` is not supported by current version of Qdrant client `{client_version}`."
)

@staticmethod
def _check_versions(client_version: str, server_version: str) -> bool:
client = version.parse(client_version)
server = version.parse(server_version)
if client_version == server_version:
return True
major_dif = abs(server.major - client.major)
if major_dif >= 1:
return False
elif major_dif == 0:
return abs(server.minor - client.minor) <= 1
return False

@property
def closed(self) -> bool:
Expand Down
2 changes: 2 additions & 0 deletions qdrant_client/qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(
Union[Callable[[], str], Callable[[], Awaitable[str]]]
] = None,
cloud_inference: bool = False,
check_version: Optional[bool] = None,
**kwargs: Any,
):
self._inference_inspector = Inspector()
Expand Down Expand Up @@ -144,6 +145,7 @@ def __init__(
host=host,
grpc_options=grpc_options,
auth_token_provider=auth_token_provider,
check_version=check_version,
**kwargs,
)

Expand Down
25 changes: 25 additions & 0 deletions qdrant_client/qdrant_remote.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import importlib.metadata
import logging
import math
import warnings
Expand All @@ -19,6 +20,7 @@
import httpx
import numpy as np
from grpc import Compression
from packaging import version
from urllib3.util import Url, parse_url

from qdrant_client import grpc as grpc
Expand Down Expand Up @@ -58,6 +60,7 @@ def __init__(
auth_token_provider: Optional[
Union[Callable[[], str], Callable[[], Awaitable[str]]]
] = None,
check_version: Optional[bool] = None,
**kwargs: Any,
):
super().__init__(**kwargs)
Expand Down Expand Up @@ -186,6 +189,28 @@ def __init__(

self._closed: bool = False

if check_version:
client_version = importlib.metadata.version("qdrant-client")
server_version = self.info().version
is_ok = self._check_versions(client_version, server_version)
if not is_ok:
warnings.warn(
f"Found Qdrant server version `{server_version}` is not supported by current version of Qdrant client `{client_version}`."
)

@staticmethod
def _check_versions(client_version: str, server_version: str) -> bool:
client = version.parse(client_version)
server = version.parse(server_version)
if client_version == server_version:
return True
major_dif = abs(server.major - client.major)
if major_dif >= 1:
return False
elif major_dif == 0:
return abs(server.minor - client.minor) <= 1
return False

@property
def closed(self) -> bool:
return self._closed
Expand Down
54 changes: 54 additions & 0 deletions tests/test_qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ def test_client_init():
assert isinstance(client._client, QdrantRemote)
assert client._client.rest_uri == "http://localhost:6333"

client = QdrantClient(":memory:", check_version=True)
assert isinstance(client._client, QdrantLocal)

client = QdrantClient(check_version=True)
assert isinstance(client._client, QdrantRemote)

client = QdrantClient(https=True)
assert isinstance(client._client, QdrantRemote)
assert client._client.rest_uri == "https://localhost:6333"
Expand Down Expand Up @@ -183,6 +189,54 @@ def test_client_init():
assert client.init_options["metadata"] == {"some-rest-meta": "some-value"}


@pytest.mark.parametrize(
"test_data",
[
("1.9.3.dev0", "2.0.1", False, "Diff between major versions = 1, minor versions differ"),
("1.9.0", "2.9.0", False, "Diff between major versions = 1, minor versions are the same"),
(
"1.1.0",
"1.2.9",
True,
"Diff between major versions == 0, diff between minor versions == 1 (server > client)",
),
(
"1.2.7",
"1.1.8.dev0",
True,
"Diff between major versions == 0, diff between minor versions == 1 (client > server)",
),
(
"1.2.1",
"1.2.29",
True,
"Diff between major versions == 0, diff between minor versions == 0",
),
("1.2.0", "1.2.0", True, "Same versions"),
(
"1.2.0",
"1.4.0",
False,
"Diff between major versions == 0, diff between minor versions > 1 (server > client)",
),
(
"1.4.0",
"1.2.0",
False,
"Diff between major versions == 0, diff between minor versions > 1 (client > server)",
),
("1.9.0", "3.7.0", False, "Diff between major versions > 1 (server > client)"),
("3.0.0", "1.0.0", False, "Diff between major versions > 1 (client > server)"),
],
)
def test_check_versions(test_data):
client = QdrantClient()
assert (
client._client._check_versions(client_version=test_data[0], server_version=test_data[1])
is test_data[2]
)


@pytest.mark.parametrize("prefer_grpc", [False, True])
@pytest.mark.parametrize("parallel", [1, 2])
def test_records_upload(prefer_grpc, parallel):
Expand Down
Loading