From 0c819b1a1d9a9d8fdda89853a429ed42b239fe7a Mon Sep 17 00:00:00 2001 From: tellet-q Date: Thu, 5 Dec 2024 15:51:30 +0100 Subject: [PATCH] Address review --- qdrant_client/async_qdrant_client.py | 4 +- qdrant_client/async_qdrant_remote.py | 15 ++-- qdrant_client/common/version_check.py | 52 ++++++------ qdrant_client/qdrant_client.py | 4 +- qdrant_client/qdrant_remote.py | 15 ++-- tests/test_common.py | 111 ++++++++++++-------------- tests/test_qdrant_client.py | 6 +- 7 files changed, 99 insertions(+), 108 deletions(-) diff --git a/qdrant_client/async_qdrant_client.py b/qdrant_client/async_qdrant_client.py index 2c08ec01..73c9c197 100644 --- a/qdrant_client/async_qdrant_client.py +++ b/qdrant_client/async_qdrant_client.py @@ -94,7 +94,7 @@ def __init__( Union[Callable[[], str], Callable[[], Awaitable[str]]] ] = None, cloud_inference: bool = False, - check_version: Optional[bool] = None, + check_compatibility: Optional[bool] = True, **kwargs: Any, ): self._inference_inspector = Inspector() @@ -133,7 +133,7 @@ def __init__( host=host, grpc_options=grpc_options, auth_token_provider=auth_token_provider, - check_version=check_version, + check_compatibility=check_compatibility, **kwargs, ) if isinstance(self._client, AsyncQdrantLocal) and cloud_inference: diff --git a/qdrant_client/async_qdrant_remote.py b/qdrant_client/async_qdrant_remote.py index d3ea482a..942f301f 100644 --- a/qdrant_client/async_qdrant_remote.py +++ b/qdrant_client/async_qdrant_remote.py @@ -35,7 +35,7 @@ from qdrant_client._pydantic_compat import construct from qdrant_client.auth import BearerAuth from qdrant_client.async_client_base import AsyncQdrantBase -from qdrant_client.common.version_check import is_server_version_compatible +from qdrant_client.common.version_check import is_versions_compatible, get_server_version from qdrant_client.connection import get_async_channel as get_channel from qdrant_client.conversions import common_types as types from qdrant_client.conversions.common_types import get_args_subscribed @@ -69,7 +69,7 @@ def __init__( auth_token_provider: Optional[ Union[Callable[[], str], Callable[[], Awaitable[str]]] ] = None, - check_version: Optional[bool] = None, + check_compatibility: Optional[bool] = True, **kwargs: Any, ): super().__init__(**kwargs) @@ -159,10 +159,13 @@ def __init__( self._grpc_snapshots_client: Optional[grpc.SnapshotsStub] = None self._grpc_root_client: Optional[grpc.QdrantStub] = None self._closed: bool = False - if check_version and (not is_server_version_compatible(self.rest_uri, **self._rest_args)): - warnings.warn( - "Qdrant client version may be incompatible with server version. Set check_version=False to skip version check." - ) + if check_compatibility: + client_version = importlib.metadata.version("qdrant-client") + server_version = get_server_version(self.rest_uri, **self._rest_args) + if not is_versions_compatible(client_version, server_version): + warnings.warn( + f"Qdrant client version {client_version} is incompatible with server version {server_version}. Major versions should mathc and minor version difference must not exceed 1. Set check_version=False to skip version check." + ) @property def closed(self) -> bool: diff --git a/qdrant_client/common/version_check.py b/qdrant_client/common/version_check.py index a09af624..c9c9ca53 100644 --- a/qdrant_client/common/version_check.py +++ b/qdrant_client/common/version_check.py @@ -1,41 +1,32 @@ -import importlib.metadata import logging +import warnings from typing import Union, Any from collections import namedtuple from qdrant_client.http import SyncApis, ApiClient -from qdrant_client.http.models import models Version = namedtuple("Version", ["major", "minor", "rest"]) -def is_server_version_compatible(rest_uri: str, **kwargs: Any) -> bool: - def get_server_info() -> Any: +def get_server_version(rest_uri: str, **kwargs: Any) -> Union[str, None]: + try: openapi_client: SyncApis[ApiClient] = SyncApis( host=rest_uri, **kwargs, ) - return openapi_client.client.request( - type_=models.VersionInfo, - method="GET", - url="/", - headers=None, - ) + version_info = openapi_client.service_api.root() - def get_server_version() -> Union[str, None]: try: - version_info = get_server_info() - except Exception as er: - logging.warning(f"Unable to get server version: {er}, default to None") - return None + openapi_client.close() + except Exception: + logging.warning( + "Unable to close http connection. Connection was interrupted on the server side" + ) - if not version_info: - return None return version_info.version - - client_version = importlib.metadata.version("qdrant-client") - server_version = get_server_version() - return compare_versions(client_version, server_version) + except Exception as er: + warnings.warn(f"Unable to get server version: {er}, default to None") + return None def parse_version(version: str) -> Version: @@ -50,9 +41,15 @@ def parse_version(version: str) -> Version: ) from er -def compare_versions(client_version: Union[str, None], server_version: Union[str, None]) -> bool: - if not client_version or not server_version: - logging.warning(f"Unable to compare: {client_version} vs {server_version}") +def is_versions_compatible( + client_version: Union[str, None], server_version: Union[str, None] +) -> bool: + if not client_version: + warnings.warn(f"Unable to compare with client version {client_version}") + return False + + if not server_version: + warnings.warn(f"Unable to compare with server version {server_version}") return False if client_version == server_version: @@ -62,11 +59,10 @@ def compare_versions(client_version: Union[str, None], server_version: Union[str parsed_server_version = parse_version(server_version) parsed_client_version = parse_version(client_version) except ValueError as er: - logging.warning(f"Unable to parse version: {er}") + warnings.warn(f"Unable to compare versions: {er}") return False + major_dif = abs(parsed_server_version.major - parsed_client_version.major) if major_dif >= 1: return False - elif major_dif == 0: - return abs(parsed_server_version.minor - parsed_client_version.minor) <= 1 - return False + return abs(parsed_server_version.minor - parsed_client_version.minor) <= 1 diff --git a/qdrant_client/qdrant_client.py b/qdrant_client/qdrant_client.py index a0ec8bd4..d195d917 100644 --- a/qdrant_client/qdrant_client.py +++ b/qdrant_client/qdrant_client.py @@ -94,7 +94,7 @@ def __init__( Union[Callable[[], str], Callable[[], Awaitable[str]]] ] = None, cloud_inference: bool = False, - check_version: Optional[bool] = None, + check_compatibility: Optional[bool] = True, **kwargs: Any, ): self._inference_inspector = Inspector() @@ -145,7 +145,7 @@ def __init__( host=host, grpc_options=grpc_options, auth_token_provider=auth_token_provider, - check_version=check_version, + check_compatibility=check_compatibility, **kwargs, ) diff --git a/qdrant_client/qdrant_remote.py b/qdrant_client/qdrant_remote.py index 9930ee13..81cb7277 100644 --- a/qdrant_client/qdrant_remote.py +++ b/qdrant_client/qdrant_remote.py @@ -27,7 +27,7 @@ from qdrant_client._pydantic_compat import construct from qdrant_client.auth import BearerAuth from qdrant_client.client_base import QdrantBase -from qdrant_client.common.version_check import is_server_version_compatible +from qdrant_client.common.version_check import is_versions_compatible, get_server_version from qdrant_client.connection import get_async_channel, get_channel from qdrant_client.conversions import common_types as types from qdrant_client.conversions.common_types import get_args_subscribed @@ -61,7 +61,7 @@ def __init__( auth_token_provider: Optional[ Union[Callable[[], str], Callable[[], Awaitable[str]]] ] = None, - check_version: Optional[bool] = None, + check_compatibility: Optional[bool] = True, **kwargs: Any, ): super().__init__(**kwargs) @@ -196,10 +196,13 @@ def __init__( self._closed: bool = False - if check_version and not is_server_version_compatible(self.rest_uri, **self._rest_args): - warnings.warn( - "Qdrant client version may be incompatible with server version. Set check_version=False to skip version check." - ) + if check_compatibility: + client_version = importlib.metadata.version("qdrant-client") + server_version = get_server_version(self.rest_uri, **self._rest_args) + if not is_versions_compatible(client_version, server_version): + warnings.warn( + f"Qdrant client version {client_version} is incompatible with server version {server_version}. Major versions should mathc and minor version difference must not exceed 1. Set check_version=False to skip version check." + ) @property def closed(self) -> bool: diff --git a/tests/test_common.py b/tests/test_common.py index d6026eba..801e2692 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,77 +1,66 @@ import pytest -from qdrant_client.common.version_check import compare_versions, parse_version +from qdrant_client.common.version_check import is_versions_compatible, parse_version @pytest.mark.parametrize( - "test_data", + "client_version, server_version, expected_result", [ - ("1.9.3.dev0", "2.0.1", False, "Diff between major versions = 1, minor versions differ"), - ( - "1.9", - "2.0", - False, - "Diff between major versions = 1, minor versions differ, only major and patch", - ), - ("1", "2", False, "Diff between major versions = 1, minor versions differ, only major"), - ("1.9.0", "2.9.0", False, "Diff between major versions = 1, minor versions are the same"), - ( - "1.1.0", - "1.2.9", - True, - "Diff between major versions == 0, diff between minor versions == 1 (server > client)", - ), - ( - "1.2.7", - "1.1.8.dev0", - True, - "Diff between major versions == 0, diff between minor versions == 1 (client > server)", - ), - ( - "1.2.1", - "1.2.29", - True, - "Diff between major versions == 0, diff between minor versions == 0", - ), - ("1.2.0", "1.2.0", True, "Same versions"), - ( - "1.2.0", - "1.4.0", - False, - "Diff between major versions == 0, diff between minor versions > 1 (server > client)", - ), - ( - "1.4.0", - "1.2.0", - False, - "Diff between major versions == 0, diff between minor versions > 1 (client > server)", - ), - ("1.9.0", "3.7.0", False, "Diff between major versions > 1 (server > client)"), - ("3.0.0", "1.0.0", False, "Diff between major versions > 1 (client > server)"), - (None, "1.0.0", False, "Client version is None"), - ("1.0.0", None, False, "Server version is None"), - (None, None, False, "Both versions are None"), + ("1.9.3.dev0", "2.8.1.dev12-something", False), + ("1.9", "2.8", False), + ("1", "2", False), + ("1.9.0", "2.9.0", False), + ("1.1.0", "1.2.9", True), + ("1.2.7", "1.1.8.dev0", True), + ("1.2.1", "1.2.29", True), + ("1.2.0", "1.2.0", True), + ("1.2.0", "1.4.0", False), + ("1.4.0", "1.2.0", False), + ("1.9.0", "3.7.0", False), + ("3.0.0", "1.0.0", False), + (None, "1.0.0", False), + ("1.0.0", None, False), + (None, None, False), + ], + ids=[ + "Diff between major versions = 1, negative", + "Diff between major versions = 1, only major and minor, negative", + "Diff between major versions = 1, only major, negative", + "Diff between major versions = 1, minor versions are the same, negative", + "Diff between major versions == 0, diff between minor versions == 1 (server > client), positive", + "Diff between major versions == 0, diff between minor versions == 1 (client > server), positive", + "Diff between major versions == 0, diff between minor versions == 0, positive", + "Same versions, positive", + "Diff between major versions == 0, diff between minor versions > 1 (server > client), negative", + "Diff between major versions == 0, diff between minor versions > 1 (client > server), negative", + "Diff between major versions > 1 (server > client), negative", + "Diff between major versions > 1 (client > server), negative", + "Client version is None, negative", + "Server version is None, negative", + "Both versions are None, negative", ], ) -def test_check_versions(test_data): +def test_check_versions(client_version, server_version, expected_result): assert ( - compare_versions(client_version=test_data[0], server_version=test_data[1]) is test_data[2] + is_versions_compatible(client_version=client_version, server_version=server_version) + is expected_result ) @pytest.mark.parametrize( - "test_data", - [ - ("1", "Only major version"), - ("1.", "Only major version"), - (".1", "Only minor version"), - (".1.", "Only minor version"), - ("1.None.1", "Minor version is not a number"), - ("None.0.1", "Major version is not a number"), - (None, "Version is None"), - ("", "Version is empty"), + "input_version", + ["1", "1.", ".1", ".1.", "1.None.1", "None.0.1", None, ""], + ids=[ + "Only major part", + "Only major part with dot", + "Only minor part", + "Only minor part with dot", + "Minor part is not a number", + "Major part is not a number", + "Version is None", + "Version is empty", ], ) -def test_parse_versions_value_error(test_data): +def test_parse_versions_value_error(input_version): with pytest.raises(ValueError): - parse_version(test_data[0]) + parse_version(input_version) diff --git a/tests/test_qdrant_client.py b/tests/test_qdrant_client.py index f8956e9c..3ae61c3f 100644 --- a/tests/test_qdrant_client.py +++ b/tests/test_qdrant_client.py @@ -100,13 +100,13 @@ def test_client_init(): assert isinstance(client._client, QdrantRemote) assert client._client.rest_uri == "http://localhost:6333" - client = QdrantClient(":memory:", check_version=True) + client = QdrantClient(":memory:", check_compatibility=True) assert isinstance(client._client, QdrantLocal) - client = QdrantClient(check_version=True) + client = QdrantClient(check_compatibility=True) assert isinstance(client._client, QdrantRemote) - client = QdrantClient(check_version=True, prefer_grpc=True) + client = QdrantClient(check_compatibility=True, prefer_grpc=True) assert isinstance(client._client, QdrantRemote) client = QdrantClient(https=True)