Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feature: remove _ClusterBatch #1259

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions integration/test_cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from contextlib import contextmanager
from typing import Generator, List

import weaviate
from weaviate.collections.classes.config import (
Configure,
DataType,
Property,
)


COLLECTION_NAME_PREFIX = "Collection_test_cluster"
NODE_NAME = "node1"
NUM_OBJECT = 10


@contextmanager
def get_weaviate_client(
    collection_names: List[str],
) -> Generator[weaviate.WeaviateClient, None, None]:
    """Yield a local Weaviate client with the given collections cleaned up.

    The listed collections are deleted before the client is handed to the
    caller and again afterwards, and the client is always closed.

    Args:
        collection_names: Names of the collections the test owns; each one is
            deleted before and after the ``with`` body runs.

    Yields:
        The connected ``weaviate.WeaviateClient``.
    """
    client = weaviate.connect_to_local()
    try:
        # Start from a clean slate in case a previous run left data behind.
        for collection_name in collection_names:
            client.collections.delete(collection_name)
        try:
            yield client
        finally:
            # Runs even when the test body raises (e.g. a failed assert), so
            # leftover collections cannot leak into other tests.
            for collection_name in collection_names:
                client.collections.delete(collection_name)
    finally:
        # Always release the connection, including when a deletion fails.
        client.close()


def test_rest_nodes_without_data() -> None:
    """Verbose node status on an empty instance reports one healthy, empty node."""
    with get_weaviate_client([]) as client:
        nodes = client.cluster.rest_nodes(output="verbose")
        assert len(nodes) == 1

        # Single-node setup: every remaining check targets that one entry.
        node = nodes[0]
        assert "gitHash" in node
        assert node["name"] == NODE_NAME
        assert node["shards"] is None

        stats = node["stats"]
        assert stats["objectCount"] == 0
        assert stats["shardCount"] == 0

        assert node["status"] == "HEALTHY"
        assert "version" in node


def test_rest_nodes_with_data() -> None:
    """Check the verbose node status after two collections have been filled.

    Creates two collections holding ``NUM_OBJECT`` and ``NUM_OBJECT * 2``
    objects, then queries the node status three ways: for the whole cluster,
    for the first collection by its exact name, and for the first collection
    with a lower-cased first letter.
    """
    collection_name_1 = f"{COLLECTION_NAME_PREFIX}_rest_nodes_with_data_1"
    collection_name_2 = f"{COLLECTION_NAME_PREFIX}_rest_nodes_with_data_2"
    # Same name with a lower-case first letter, to check that the lookup still
    # resolves to the same collection.
    uncap_collection_name_1 = collection_name_1[0].lower() + collection_name_1[1:]

    with get_weaviate_client([collection_name_1, collection_name_2]) as client:
        collection = client.collections.create(
            name=collection_name_1,
            properties=[Property(name="Name", data_type=DataType.TEXT)],
            vectorizer_config=Configure.Vectorizer.none(),
        )
        collection.data.insert_many([{"Name": f"name {i}"} for i in range(NUM_OBJECT)])

        collection = client.collections.create(
            name=collection_name_2,
            properties=[Property(name="Name", data_type=DataType.TEXT)],
            vectorizer_config=Configure.Vectorizer.none(),
        )
        collection.data.insert_many([{"Name": f"name {i}"} for i in range(NUM_OBJECT * 2)])

        # server behaviour changed by https://github.com/weaviate/weaviate/pull/4203
        server_is_at_least_124 = client._connection._weaviate_version.is_at_least(1, 24, 0)

        # Compute the expected counts up front. Writing
        # ``assert a == 0 if cond else N`` parses as
        # ``assert ((a == 0) if cond else N)``, which always passes on older
        # servers because ``N`` is a truthy constant.
        expected_count_1 = 0 if server_is_at_least_124 else NUM_OBJECT
        expected_count_2 = 0 if server_is_at_least_124 else NUM_OBJECT * 2
        expected_count_total = 0 if server_is_at_least_124 else NUM_OBJECT * 3

        # Whole cluster: both collections are visible.
        resp = client.cluster.rest_nodes(output="verbose")
        assert len(resp) == 1
        assert "gitHash" in resp[0]
        assert resp[0]["name"] == NODE_NAME
        assert resp[0]["shards"] is not None and len(resp[0]["shards"]) == 2
        assert resp[0]["stats"]["objectCount"] == expected_count_total
        assert resp[0]["stats"]["shardCount"] == 2
        assert resp[0]["status"] == "HEALTHY"
        assert "version" in resp[0]

        # Sort by class name so the shard order is deterministic.
        shards = sorted(resp[0]["shards"], key=lambda x: x["class"])
        assert shards[0]["class"] == collection_name_1
        assert shards[0]["objectCount"] == expected_count_1
        assert shards[1]["class"] == collection_name_2
        assert shards[1]["objectCount"] == expected_count_2

        # Restricted to the first collection by its exact name.
        resp = client.cluster.rest_nodes(collection=collection_name_1, output="verbose")
        assert len(resp) == 1
        assert "gitHash" in resp[0]
        assert resp[0]["name"] == NODE_NAME
        assert resp[0]["shards"] is not None and len(resp[0]["shards"]) == 1
        assert resp[0]["stats"]["shardCount"] == 1
        assert resp[0]["status"] == "HEALTHY"
        assert "version" in resp[0]
        assert resp[0]["stats"]["objectCount"] == expected_count_1

        # Restricted to the first collection, passed with a lower-case first letter.
        resp = client.cluster.rest_nodes(uncap_collection_name_1, output="verbose")
        assert len(resp) == 1
        assert "gitHash" in resp[0]
        assert resp[0]["name"] == NODE_NAME
        assert resp[0]["shards"] is not None and len(resp[0]["shards"]) == 1
        assert resp[0]["stats"]["shardCount"] == 1
        assert resp[0]["status"] == "HEALTHY"
        assert "version" in resp[0]
        assert resp[0]["stats"]["objectCount"] == expected_count_1
33 changes: 5 additions & 28 deletions weaviate/collections/batch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,11 @@
from collections import deque
from copy import copy
from dataclasses import dataclass, field
from typing import Any, Dict, Generic, List, Optional, Set, TypeVar, Union, cast
from typing import Any, Dict, Generic, List, Optional, Set, TypeVar, Union

from pydantic import ValidationError
from typing_extensions import TypeAlias

from httpx import ConnectError

from weaviate.cluster.types import Node
from weaviate.collections.batch.grpc_batch_objects import _BatchGRPC
from weaviate.collections.batch.rest import _BatchREST
from weaviate.collections.classes.batch import (
Expand All @@ -35,12 +32,12 @@
ReferenceInputs,
)
from weaviate.collections.classes.types import WeaviateProperties
from weaviate.collections.cluster import _ClusterAsync
from weaviate.connect import ConnectionV4
from weaviate.event_loop import _EventLoop
from weaviate.exceptions import WeaviateBatchValidationError, EmptyResponseException
from weaviate.exceptions import WeaviateBatchValidationError
from weaviate.logger import logger
from weaviate.types import UUID, VECTORS
from weaviate.util import _decode_json_response_dict
from weaviate.warnings import _Warnings

BatchResponse = List[Dict[str, Any]]
Expand Down Expand Up @@ -183,7 +180,7 @@ def __init__(

self.__results_lock = threading.Lock()

self.__cluster = _ClusterBatch(self.__connection)
self.__cluster = _ClusterAsync(self.__connection)

self.__batching_mode: _BatchMode = batch_mode
self.__max_batch_size: int = 1000
Expand Down Expand Up @@ -360,7 +357,7 @@ def batch_send_wrapper() -> None:
return demonBatchSend

def __dynamic_batching(self) -> None:
status = self.__loop.run_until_complete(self.__cluster.get_nodes_status)
status = self.__loop.run_until_complete(self.__cluster.rest_nodes)
if "batchStats" not in status[0] or "queueLength" not in status[0]["batchStats"]:
# async indexing - just send a lot
self.__batching_mode = _FixedSizeBatching(1000, 10)
Expand Down Expand Up @@ -700,23 +697,3 @@ def __check_bg_thread_alive(self) -> None:
return

raise self.__bg_thread_exception or Exception("Batch thread died unexpectedly")


class _ClusterBatch:
    """Small helper that fetches the node status through a connection."""

    def __init__(self, connection: ConnectionV4):
        self._connection = connection

    async def get_nodes_status(
        self,
    ) -> List[Node]:
        """Request ``/nodes`` and return the ``nodes`` entry of the response.

        Raises:
            ConnectError: If the request fails with a connection error.
            EmptyResponseException: If ``nodes`` is missing or an empty list.
        """
        try:
            raw_response = await self._connection.get(path="/nodes")
        except ConnectError as conn_err:
            raise ConnectError("Get nodes status failed due to connection error") from conn_err

        decoded = _decode_json_response_dict(raw_response, "Nodes status")
        assert decoded is not None

        node_list = decoded.get("nodes")
        if node_list is not None and node_list != []:
            return cast(List[Node], node_list)
        raise EmptyResponseException("Nodes status response returned empty")
25 changes: 15 additions & 10 deletions weaviate/collections/cluster/cluster.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from weaviate.connect import ConnectionV4


from typing import List, Literal, Optional, Union, overload
from typing import List, Literal, Optional, Union, cast, overload

from weaviate.cluster.types import Node as NodeREST
from weaviate.collections.classes.cluster import Node, Shards, _ConvertFromREST, Stats
from weaviate.connect import ConnectionV4
from weaviate.exceptions import (
EmptyResponseError,
)

from weaviate.util import _capitalize_first_letter, _decode_json_response_dict


Expand Down Expand Up @@ -73,6 +71,17 @@ async def nodes(
`weaviate.EmptyResponseError`
If the response is empty.
"""
nodes = await self.rest_nodes(collection, output)
if output == "verbose":
return _ConvertFromREST.nodes_verbose(nodes)
else:
return _ConvertFromREST.nodes_minimal(nodes)

async def rest_nodes(
self,
collection: Optional[str] = None,
output: Optional[Literal["minimal", "verbose"]] = None,
) -> List[NodeREST]:
Comment on lines +80 to +84
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you are adding this method to _ClusterAsync, you also need to add its synchronous equivalent signature to the sync stubs in collections/cluster/sync.pyi alongside the stubs for the def nodes(...) method, cheers!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done as well

path = "/nodes"
if collection is not None:
path += "/" + _capitalize_first_letter(collection)
Expand All @@ -86,8 +95,4 @@ async def nodes(
nodes = response_typed.get("nodes")
if nodes is None or nodes == []:
raise EmptyResponseError("Nodes status response returned empty")

if output == "verbose":
return _ConvertFromREST.nodes_verbose(nodes)
else:
return _ConvertFromREST.nodes_minimal(nodes)
return cast(List[NodeREST], nodes)
6 changes: 6 additions & 0 deletions weaviate/collections/cluster/sync.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, Literal, Optional, overload

from weaviate.cluster.types import Node as NodeREST
from weaviate.collections.classes.cluster import Node, Shards, Stats
from weaviate.collections.cluster.cluster import _ClusterBase

Expand All @@ -25,3 +26,8 @@ class _Cluster(_ClusterBase):
*,
output: Literal["verbose"],
) -> List[Node[Shards, Stats]]: ...
# Synchronous stub matching the async `rest_nodes` method of the cluster class.
# `collection` optionally restricts the request to one collection and `output`
# selects "minimal" or "verbose" detail; the result is a list of `NodeREST`
# entries rather than the converted `Node[...]` objects returned by `nodes`.
def rest_nodes(
self,
collection: Optional[str] = None,
output: Optional[Literal["minimal", "verbose"]] = None,
) -> List[NodeREST]: ...