Skip to content

Commit

Permalink
Pass auth provider instance
Browse files Browse the repository at this point in the history
  • Loading branch information
tellet-q committed Dec 6, 2024
1 parent dd42ab9 commit 00d63ff
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 17 deletions.
4 changes: 3 additions & 1 deletion qdrant_client/async_qdrant_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ def __init__(
self._closed: bool = False
if check_compatibility:
client_version = importlib.metadata.version("qdrant-client")
server_version = get_server_version(self.rest_uri, self._rest_headers)
server_version = get_server_version(
self.rest_uri, self._rest_headers, self._rest_args.get("auth")
)
if not is_versions_compatible(client_version, server_version):
warnings.warn(
f"Qdrant client version {client_version} is incompatible with server version {server_version}. Major versions should match and minor version difference must not exceed 1. Set check_version=False to skip version check."
Expand Down
22 changes: 12 additions & 10 deletions qdrant_client/common/version_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,29 @@
from collections import namedtuple

import httpx
from pydantic import ValidationError

from qdrant_client.http.api_client import parse_as_type
from qdrant_client.http.models import models
from qdrant_client.auth import BearerAuth

Version = namedtuple("Version", ["major", "minor", "rest"])


def get_server_version(rest_uri: str, rest_headers: Dict[str, Any]) -> Union[str, None]:
def get_server_version(
rest_uri: str, rest_headers: Dict[str, Any], auth_provider: Union[BearerAuth, None]
) -> Union[str, None]:
try:
response = httpx.get(rest_uri + "/", headers=rest_headers)
if auth_provider:
response = httpx.get(rest_uri + "/", headers=rest_headers, auth=auth_provider)
else:
response = httpx.get(rest_uri + "/", headers=rest_headers)
except Exception as er:
warnings.warn(f"Unable to get server version: {er}, default to None")
return None

if response.status_code in [200, 201, 202]:
try:
version_info = parse_as_type(response.json(), models.VersionInfo)
return version_info.version
except ValidationError as e:
if response.status_code == 200:
version_info = response.json().get("version", None)
if not version_info:
warnings.warn(f"Unable to parse response from server: {response}, default to None")
return version_info
else:
warnings.warn(f"Unexpected response from server: {response}, default to None")
return None
Expand Down
4 changes: 3 additions & 1 deletion qdrant_client/qdrant_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,9 @@ def __init__(

if check_compatibility:
client_version = importlib.metadata.version("qdrant-client")
server_version = get_server_version(self.rest_uri, self._rest_headers)
server_version = get_server_version(
self.rest_uri, self._rest_headers, self._rest_args.get("auth")
)
if not is_versions_compatible(client_version, server_version):
warnings.warn(
f"Qdrant client version {client_version} is incompatible with server version {server_version}. Major versions should match and minor version difference must not exceed 1. Set check_version=False to skip version check."
Expand Down
11 changes: 6 additions & 5 deletions tests/test_async_qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,12 +588,13 @@ def auth_token_provider():
call_num += 1
return sync_token

# Additional sync request is sent during client init to check compatibility
client = AsyncQdrantClient(timeout=3, auth_token_provider=auth_token_provider)
await client.get_collections()
assert sync_token == "token_0"
assert sync_token == "token_1"

await client.get_collections()
assert sync_token == "token_1"
assert sync_token == "token_2"

sync_token = ""
call_num = 0
Expand All @@ -602,13 +603,13 @@ def auth_token_provider():
prefer_grpc=True, timeout=3, auth_token_provider=auth_token_provider
)
await client.get_collections()
assert sync_token == "token_0"
assert sync_token == "token_1"

await client.get_collections()
assert sync_token == "token_1"
assert sync_token == "token_2"

await client.unlock_storage()
assert sync_token == "token_2"
assert sync_token == "token_3"


@pytest.mark.asyncio
Expand Down

0 comments on commit 00d63ff

Please sign in to comment.