From 00d63ffb90056e65be931b9dc469217e7fbc33ec Mon Sep 17 00:00:00 2001 From: tellet-q Date: Fri, 6 Dec 2024 13:36:50 +0100 Subject: [PATCH] Pass auth provider instance --- qdrant_client/async_qdrant_remote.py | 4 +++- qdrant_client/common/version_check.py | 22 ++++++++++++---------- qdrant_client/qdrant_remote.py | 4 +++- tests/test_async_qdrant_client.py | 11 ++++++----- 4 files changed, 24 insertions(+), 17 deletions(-) diff --git a/qdrant_client/async_qdrant_remote.py b/qdrant_client/async_qdrant_remote.py index 8294726d..da6c5436 100644 --- a/qdrant_client/async_qdrant_remote.py +++ b/qdrant_client/async_qdrant_remote.py @@ -161,7 +161,9 @@ def __init__( self._closed: bool = False if check_compatibility: client_version = importlib.metadata.version("qdrant-client") - server_version = get_server_version(self.rest_uri, self._rest_headers) + server_version = get_server_version( + self.rest_uri, self._rest_headers, self._rest_args.get("auth") + ) if not is_versions_compatible(client_version, server_version): warnings.warn( f"Qdrant client version {client_version} is incompatible with server version {server_version}. Major versions should match and minor version difference must not exceed 1. Set check_version=False to skip version check." diff --git a/qdrant_client/common/version_check.py b/qdrant_client/common/version_check.py index 67fa717a..e1f94644 100644 --- a/qdrant_client/common/version_check.py +++ b/qdrant_client/common/version_check.py @@ -3,27 +3,29 @@ from collections import namedtuple import httpx -from pydantic import ValidationError -from qdrant_client.http.api_client import parse_as_type -from qdrant_client.http.models import models +from qdrant_client.auth import BearerAuth Version = namedtuple("Version", ["major", "minor", "rest"]) -def get_server_version(rest_uri: str, rest_headers: Dict[str, Any]) -> Union[str, None]: +def get_server_version( + rest_uri: str, rest_headers: Dict[str, Any], auth_provider: Union[BearerAuth, None] +) -> Union[str, None]: try: - response = httpx.get(rest_uri + "/", headers=rest_headers) + if auth_provider: + response = httpx.get(rest_uri + "/", headers=rest_headers, auth=auth_provider) + else: + response = httpx.get(rest_uri + "/", headers=rest_headers) except Exception as er: warnings.warn(f"Unable to get server version: {er}, default to None") return None - if response.status_code in [200, 201, 202]: - try: - version_info = parse_as_type(response.json(), models.VersionInfo) - return version_info.version - except ValidationError as e: + if response.status_code == 200: + version_info = response.json().get("version", None) + if not version_info: warnings.warn(f"Unable to parse response from server: {response}, default to None") + return version_info else: warnings.warn(f"Unexpected response from server: {response}, default to None") return None diff --git a/qdrant_client/qdrant_remote.py b/qdrant_client/qdrant_remote.py index 791ba4b7..7809d622 100644 --- a/qdrant_client/qdrant_remote.py +++ b/qdrant_client/qdrant_remote.py @@ -198,7 +198,9 @@ def __init__( if check_compatibility: client_version = importlib.metadata.version("qdrant-client") - server_version = get_server_version(self.rest_uri, self._rest_headers) + server_version = get_server_version( + self.rest_uri, self._rest_headers, self._rest_args.get("auth") + ) if not is_versions_compatible(client_version, server_version): warnings.warn( f"Qdrant client version {client_version} is incompatible with server version {server_version}. Major versions should match and minor version difference must not exceed 1. Set check_version=False to skip version check." diff --git a/tests/test_async_qdrant_client.py b/tests/test_async_qdrant_client.py index 4174f6fd..837830c8 100644 --- a/tests/test_async_qdrant_client.py +++ b/tests/test_async_qdrant_client.py @@ -588,12 +588,13 @@ def auth_token_provider(): call_num += 1 return sync_token + # Additional sync request is sent during client init to check compatibility client = AsyncQdrantClient(timeout=3, auth_token_provider=auth_token_provider) await client.get_collections() - assert sync_token == "token_0" + assert sync_token == "token_1" await client.get_collections() - assert sync_token == "token_1" + assert sync_token == "token_2" sync_token = "" call_num = 0 @@ -602,13 +603,13 @@ def auth_token_provider(): prefer_grpc=True, timeout=3, auth_token_provider=auth_token_provider ) await client.get_collections() - assert sync_token == "token_0" + assert sync_token == "token_1" await client.get_collections() - assert sync_token == "token_1" + assert sync_token == "token_2" await client.unlock_storage() - assert sync_token == "token_2" + assert sync_token == "token_3" @pytest.mark.asyncio