Skip to content

Commit

Permalink
Removed async_grpc (#866)
Browse files Browse the repository at this point in the history
* Removed async_grpc

* Removal of asyncio

* Return needed test

* Proper fix of renaming

---------

Co-authored-by: d.rudenko <[email protected]>
  • Loading branch information
I8dNLo and d.rudenko authored Dec 27, 2024
1 parent 637e2c0 commit 3fbdf04
Show file tree
Hide file tree
Showing 7 changed files with 3 additions and 241 deletions.
34 changes: 0 additions & 34 deletions qdrant_client/qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,40 +192,6 @@ def grpc_points(self) -> grpc.PointsStub:

raise NotImplementedError(f"gRPC client is not supported for {type(self._client)}")

@property
def async_grpc_points(self) -> grpc.PointsStub:
"""gRPC client for points methods
Returns:
An instance of raw gRPC client, generated from Protobuf
"""
warnings.warn(
"async_grpc_points is deprecated and will be removed in a future release. Use `AsyncQdrantRemote.grpc_points` instead.",
DeprecationWarning,
stacklevel=2,
)
if isinstance(self._client, QdrantRemote):
return self._client.async_grpc_points

raise NotImplementedError(f"gRPC client is not supported for {type(self._client)}")

@property
def async_grpc_collections(self) -> grpc.CollectionsStub:
"""gRPC client for collections methods
Returns:
An instance of raw gRPC client, generated from Protobuf
"""
warnings.warn(
"async_grpc_collections is deprecated and will be removed in a future release. Use `AsyncQdrantRemote.grpc_collections` instead.",
DeprecationWarning,
stacklevel=2,
)
if isinstance(self._client, QdrantRemote):
return self._client.async_grpc_collections

raise NotImplementedError(f"gRPC client is not supported for {type(self._client)}")

@property
def rest(self) -> SyncApis[ApiClient]:
"""REST Client
Expand Down
99 changes: 1 addition & 98 deletions qdrant_client/qdrant_remote.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import importlib.metadata
import logging
import math
Expand Down Expand Up @@ -28,7 +27,7 @@
from qdrant_client.auth import BearerAuth
from qdrant_client.client_base import QdrantBase
from qdrant_client.common.version_check import is_compatible, get_server_version
from qdrant_client.connection import get_async_channel, get_channel
from qdrant_client.connection import get_channel
from qdrant_client.conversions import common_types as types
from qdrant_client.conversions.common_types import get_args_subscribed
from qdrant_client.conversions.conversion import (
Expand Down Expand Up @@ -188,7 +187,6 @@ def __init__(
self._grpc_snapshots_client: Optional[grpc.SnapshotsStub] = None
self._grpc_root_client: Optional[grpc.QdrantStub] = None

self._aio_grpc_channel = None
self._aio_grpc_points_client: Optional[grpc.PointsStub] = None
self._aio_grpc_collections_client: Optional[grpc.CollectionsStub] = None
self._aio_grpc_snapshots_client: Optional[grpc.SnapshotsStub] = None
Expand Down Expand Up @@ -223,17 +221,6 @@ def close(self, grpc_grace: Optional[float] = None, **kwargs: Any) -> None:
"Unable to close grpc_channel. Connection was interrupted on the server side"
)

if hasattr(self, "_aio_grpc_channel") and self._aio_grpc_channel is not None:
try:
loop = asyncio.get_running_loop()
loop.create_task(self._aio_grpc_channel.close(grace=grpc_grace))
except AttributeError:
logging.warning(
"Unable to close aio_grpc_channel. Connection was interrupted on the server side"
)
except RuntimeError:
pass

try:
self.openapi_client.close()
except Exception:
Expand Down Expand Up @@ -271,21 +258,6 @@ def _init_grpc_channel(self) -> None:
auth_token_provider=self._auth_token_provider, # type: ignore
)

def _init_async_grpc_channel(self) -> None:
if self._closed:
raise RuntimeError("Client was closed. Please create a new QdrantClient instance.")

if self._aio_grpc_channel is None:
self._aio_grpc_channel = get_async_channel(
host=self._host,
port=self._grpc_port,
ssl=self._https,
metadata=self._grpc_headers,
options=self._grpc_options,
compression=self._grpc_compression,
auth_token_provider=self._auth_token_provider,
)

def _init_grpc_points_client(self) -> None:
self._init_grpc_channel()
self._grpc_points_client = grpc.PointsStub(self._grpc_channel)
Expand All @@ -302,75 +274,6 @@ def _init_grpc_root_client(self) -> None:
self._init_grpc_channel()
self._grpc_root_client = grpc.QdrantStub(self._grpc_channel)

def _init_async_grpc_points_client(self) -> None:
self._init_async_grpc_channel()
self._aio_grpc_points_client = grpc.PointsStub(self._aio_grpc_channel)

def _init_async_grpc_collections_client(self) -> None:
self._init_async_grpc_channel()
self._aio_grpc_collections_client = grpc.CollectionsStub(self._aio_grpc_channel)

def _init_async_grpc_snapshots_client(self) -> None:
self._init_async_grpc_channel()
self._aio_grpc_snapshots_client = grpc.SnapshotsStub(self._aio_grpc_channel)

def _init_async_grpc_root_client(self) -> None:
self._init_async_grpc_channel()
self._aio_grpc_root_client = grpc.QdrantStub(self._aio_grpc_channel)

@property
def async_grpc_collections(self) -> grpc.CollectionsStub:
"""gRPC client for collections methods
Returns:
An instance of raw gRPC client, generated from Protobuf
"""
if self._aio_grpc_collections_client is None:
self._init_async_grpc_collections_client()
return self._aio_grpc_collections_client

@property
def async_grpc_points(self) -> grpc.PointsStub:
"""gRPC client for points methods
Returns:
An instance of raw gRPC client, generated from Protobuf
"""
if self._aio_grpc_points_client is None:
self._init_async_grpc_points_client()
return self._aio_grpc_points_client

@property
def async_grpc_snapshots(self) -> grpc.SnapshotsStub:
"""gRPC client for snapshots methods
Returns:
An instance of raw gRPC client, generated from Protobuf
"""
warnings.warn(
"async_grpc_snapshots is deprecated and will be removed in a future release. Use `AsyncQdrantRemote.grpc_snapshots` instead.",
DeprecationWarning,
stacklevel=2,
)
if self._aio_grpc_snapshots_client is None:
self._init_async_grpc_snapshots_client()
return self._aio_grpc_snapshots_client

@property
def async_grpc_root(self) -> grpc.QdrantStub:
"""gRPC client for info methods
Returns:
An instance of raw gRPC client, generated from Protobuf
"""
warnings.warn(
"async_grpc_root is deprecated and will be removed in a future release. Use `AsyncQdrantRemote.grpc_root` instead.",
DeprecationWarning,
stacklevel=2,
)
if self._aio_grpc_root_client is None:
self._init_async_grpc_root_client()
return self._aio_grpc_root_client

@property
def grpc_collections(self) -> grpc.CollectionsStub:
Expand Down
76 changes: 0 additions & 76 deletions tests/test_async_qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,8 @@
import pytest

import qdrant_client.http.exceptions
from qdrant_client import QdrantClient
from qdrant_client import grpc as qdrant_grpc
from qdrant_client import models
from qdrant_client.async_qdrant_client import AsyncQdrantClient
from qdrant_client.conversions.conversion import payload_to_grpc
from tests.fixtures.payload import one_random_payload_please
from tests.utils import read_version

NUM_VECTORS = 100
Expand All @@ -21,78 +17,6 @@
COLLECTION_NAME = "async_test_collection"


@pytest.mark.asyncio
async def test_async_grpc():
points = (
qdrant_grpc.PointStruct(
id=qdrant_grpc.PointId(num=idx),
vectors=qdrant_grpc.Vectors(
vector=qdrant_grpc.Vector(data=np.random.rand(DIM).tolist())
),
payload=payload_to_grpc(one_random_payload_please(idx)),
)
for idx in range(NUM_VECTORS)
)

client = QdrantClient(prefer_grpc=True, timeout=3)

grpc_collections = client.async_grpc_collections

res = await grpc_collections.List(qdrant_grpc.ListCollectionsRequest(), timeout=1.0)

for collection in res.collections:
print(collection.name)
await grpc_collections.Delete(
qdrant_grpc.DeleteCollection(collection_name=collection.name)
)

await grpc_collections.Create(
qdrant_grpc.CreateCollection(
collection_name=COLLECTION_NAME,
vectors_config=qdrant_grpc.VectorsConfig(
params=qdrant_grpc.VectorParams(size=DIM, distance=qdrant_grpc.Distance.Cosine)
),
)
)

grpc_points = client.async_grpc_points

upload_features = []

# Upload vectors in parallel
for point in points:
upload_features.append(
grpc_points.Upsert(
qdrant_grpc.UpsertPoints(
collection_name=COLLECTION_NAME, wait=True, points=[point]
)
)
)
await asyncio.gather(*upload_features)

queries = [np.random.rand(DIM).tolist() for _ in range(NUM_QUERIES)]

# Make async queries
search_queries = []
for query in queries:
search_query = grpc_points.Search(
qdrant_grpc.SearchPoints(
collection_name=COLLECTION_NAME,
vector=query,
limit=10,
)
)
search_queries.append(search_query)
results = await asyncio.gather(*search_queries) # All queries are running in parallel now

assert len(results) == NUM_QUERIES

for result in results:
assert len(result.result) == 10

client.close()


@pytest.mark.asyncio
@pytest.mark.parametrize("prefer_grpc", [True, False])
async def test_async_qdrant_client(prefer_grpc):
Expand Down
19 changes: 0 additions & 19 deletions tests/test_qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1888,24 +1888,6 @@ def test_client_close():
RuntimeError
): # prevent initializing grpc connection, since http connection is closed
_ = client_grpc_do_nothing.get_collection("test")

client_aio_grpc = QdrantClient(prefer_grpc=True, timeout=TIMEOUT)
_ = client_aio_grpc.async_grpc_collections
client_aio_grpc.close()

client_aio_grpc = QdrantClient(prefer_grpc=True, timeout=TIMEOUT)
_ = client_aio_grpc.async_grpc_collections
client_aio_grpc.close(grace=2.0)
with pytest.raises(RuntimeError):
client_aio_grpc._client._init_async_grpc_channel() # prevent reinitializing grpc connection, since
# http connection is closed

client_aio_grpc_do_nothing = QdrantClient(prefer_grpc=True, timeout=TIMEOUT)
client_aio_grpc_do_nothing.close()
with pytest.raises(
RuntimeError
): # prevent initializing grpc connection, since http connection is closed
_ = client_aio_grpc_do_nothing.async_grpc_collections
# endregion grpc

# region local
Expand Down Expand Up @@ -2098,7 +2080,6 @@ async def auth_token_provider():

assert token == ""


@pytest.mark.parametrize("prefer_grpc", [True, False])
def test_read_consistency(prefer_grpc):
fixture_points = generate_fixtures(vectors_sizes=DIM, num=NUM_VECTORS)
Expand Down
3 changes: 0 additions & 3 deletions tools/async_client_generator/client_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,6 @@ def get_async_methods(class_obj: type) -> list[str]:
exclude_methods=[
"__del__",
"migrate",
"async_grpc_collections",
"async_grpc_points",
"async_grpc_root",
],
)

Expand Down
9 changes: 0 additions & 9 deletions tools/async_client_generator/remote_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,6 @@ def get_async_methods(class_obj: type) -> list[str]:
exclude_methods=[
"__del__",
"migrate",
"async_grpc_collections",
"async_grpc_points",
"async_grpc_snapshots",
"async_grpc_root",
"_init_async_grpc_points_client",
"_init_async_grpc_collections_client",
"_init_async_grpc_snapshots_client",
"_init_async_grpc_channel",
"_init_async_grpc_root_client",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST:
if hasattr(alias, "name"):
for old_value, new_value in self.import_replace_map.items():
alias.name = alias.name.replace(old_value, new_value)
if alias.name == "get_async_channel":
if alias.name == "get_channel":
alias.name = "get_async_channel"
alias.asname = "get_channel"
node.names = [alias for alias in node.names if alias.name != "get_channel"]

return self.generic_visit(node)

0 comments on commit 3fbdf04

Please sign in to comment.